From 721b8413a048bd4b1b5b492f519a4a6f0df23aa9 Mon Sep 17 00:00:00 2001 From: Jaime Ferrando Huertas Date: Fri, 19 Nov 2021 17:32:30 +0100 Subject: [PATCH 01/59] Added boring model as a ipynb so it can be updated (#10521) Co-authored-by: Carlos Mocholi --- .github/ISSUE_TEMPLATE/bug_report.md | 2 +- pl_examples/bug_report/The_BoringModel.ipynb | 1420 +++++++++++++++++ .../{ => bug_report}/bug_report_model.py | 0 tests/loops/test_loops.py | 3 +- 4 files changed, 1422 insertions(+), 3 deletions(-) create mode 100644 pl_examples/bug_report/The_BoringModel.ipynb rename pl_examples/{ => bug_report}/bug_report_model.py (100%) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 729d258cfcd63..b7f3574f62f99 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -16,7 +16,7 @@ assignees: '' Please reproduce using the BoringModel! You can use the following Colab link: -https://colab.research.google.com/drive/1HvWVVTK8j2Nj52qU4Q4YCyzOm0_aLQF3?usp=sharing +https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report/The_BoringModel.ipynb IMPORTANT: has to be public. or this simple template: diff --git a/pl_examples/bug_report/The_BoringModel.ipynb b/pl_examples/bug_report/The_BoringModel.ipynb new file mode 100644 index 0000000000000..9b061c4283cbf --- /dev/null +++ b/pl_examples/bug_report/The_BoringModel.ipynb @@ -0,0 +1,1420 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "The BoringModel.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "d79c1628eded487a974da18a2ea1f98b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_02695b143b764932ba8d0c08a872987e", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_28eb6a3218f64f26abcdff756ffda3ad", + "IPY_MODEL_02cfffd590014c3cbc44ab06c69f9181", + "IPY_MODEL_0d7c50e36cb84f01a57a9d7d8b913393" + ] + } + }, + "02695b143b764932ba8d0c08a872987e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": "row wrap", + "width": "100%", + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": 
"1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": "inline-flex", + "left": null + } + }, + "28eb6a3218f64f26abcdff756ffda3ad": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_6ba2782883ae424dbfc8868224d95da9", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": "Epoch 0: 100%", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_baa4aacd0da64cf291fb31c000724573" + } + }, + "02cfffd590014c3cbc44ab06c69f9181": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_7dad3d2feced492a999fb6c91186be50", + "_dom_classes": [], + "description": "", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 2, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 2, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_ea702a091eb642f7bdda81aa55db8c26" + } + }, + "0d7c50e36cb84f01a57a9d7d8b913393": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_4802a47c6dfb439c83d8b860dce42006", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 2/2 [00:00<00:00, 9.45it/s, loss=-0.618, v_num=0]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_68c87e6a7fcf4e4eab98a941c7c3e867" + } + }, + "6ba2782883ae424dbfc8868224d95da9": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "baa4aacd0da64cf291fb31c000724573": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": 
null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "7dad3d2feced492a999fb6c91186be50": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "ea702a091eb642f7bdda81aa55db8c26": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": "2", + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "4802a47c6dfb439c83d8b860dce42006": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "68c87e6a7fcf4e4eab98a941c7c3e867": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + 
"padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "e6cbe583c2e14986b4faeb27e31f73e1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_672dd78899f944cea7e57f388f3ecb31", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_cd61dda59d104e0a8a8aa9bfc1e55c24", + "IPY_MODEL_1cd72d82332941a6929f88fad5173096", + "IPY_MODEL_92a38638060c4ed5b6d44a2078667e53" + ] + } + }, + "672dd78899f944cea7e57f388f3ecb31": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": "row wrap", + "width": "100%", + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": "inline-flex", + "left": null + } + }, + "cd61dda59d104e0a8a8aa9bfc1e55c24": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_bdc9b06391ee47478efd58cc91ca87ac", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": "Validating: 0%", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_ee80657d62c6452d9e9ac199157cdf2a" + } + }, + "1cd72d82332941a6929f88fad5173096": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_eb16b87bcb8d4ca6a83e8b44ea2d1311", + "_dom_classes": [], + "description": "", + "_model_name": "FloatProgressModel", + "bar_style": "", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": 
"@jupyter-widgets/controls", + "layout": "IPY_MODEL_2a6327dd568241e3acbb6aec1926bd80" + } + }, + "92a38638060c4ed5b6d44a2078667e53": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_a45aba8517e14654850453159780b54a", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 0/1 [00:00<?, ?it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_7fb167222e7143b789b7f40af7cb39dd" + } + }, + "bdc9b06391ee47478efd58cc91ca87ac": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "ee80657d62c6452d9e9ac199157cdf2a": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "eb16b87bcb8d4ca6a83e8b44ea2d1311": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "2a6327dd568241e3acbb6aec1926bd80": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + 
"overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": "2", + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "a45aba8517e14654850453159780b54a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "7fb167222e7143b789b7f40af7cb39dd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "abe1c0c4dac94e0e9b894bb69c3ec450": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_23763e19d40d4020b3342a47366e2e19", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_0b7b7da6a6134f0fb26a05adc062ee6f", + "IPY_MODEL_9941635d9d694ba7bce0c7a14c500e5e", + "IPY_MODEL_c7f1407ba92f4dc6ba34bd9cf73fea69" + ] + } + }, + "23763e19d40d4020b3342a47366e2e19": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": "row wrap", + "width": "100%", + "min_width": null, + "border": null, + "align_items": null, + "bottom": 
null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": "inline-flex", + "left": null + } + }, + "0b7b7da6a6134f0fb26a05adc062ee6f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_86f2e0a558cc419e84ed9192ccd3d1b6", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": "Testing: 100%", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_141a9c35ade14d9e8645b2c108ab4d66" + } + }, + "9941635d9d694ba7bce0c7a14c500e5e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_833bb79bb1214a3a88795f41b9375690", + "_dom_classes": [], + "description": "", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 1, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_51b32955ad544803b1d78f07bc685569" + } + }, + "c7f1407ba92f4dc6ba34bd9cf73fea69": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_a6b2764a5fa9444a9d77e8d74c67ef47", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 32/32 [00:00<00:00, 174.23it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_3f0c08c03e284ebb905dae8aca72fffc" + } + }, + "86f2e0a558cc419e84ed9192ccd3d1b6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "141a9c35ade14d9e8645b2c108ab4d66": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + 
"_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "833bb79bb1214a3a88795f41b9375690": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "51b32955ad544803b1d78f07bc685569": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": "2", + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "a6b2764a5fa9444a9d77e8d74c67ef47": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "3f0c08c03e284ebb905dae8aca72fffc": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": 
null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "rR4_BAUYs3Mb" + }, + "source": [ + "![image.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAQYAAABSCAYAAAC2XXppAAAaWElEQVR4Ae2dh3dU1RbG3x/i0mXvoighgKgI2BtW7KKiSAcFFVBQmgqKil2UJzYUQXl2ERSRpoiIIJhMOiEQkpCQRup+63eSHS73zkxmkmFC2Wetm5m55ZTv7P3tcs5M/iNWDAFDwBDwIfAf32f7aAgYAoaAGDGYEBgChkAAASOGACR2whAwBIwYTAYMAUMggIARQwASO2EIGAJGDCYDhoAhEEDAiCEAiZ0wBAwBIwaTAUPAEAggYMQQgMROGAKGgBGDyYAhYAgEEDBiCEBiJwwBQ8CIwWTAEDAEAggYMQQgsROGgCFgxGAyYAgYAgEEjBgCkNgJQ8AQMGIwGTAEDIEAAkYMAUjshCFgCBgxmAwYAoZAAAEjhgAkdsIQMASMGEwGDAFDIIBAhxJDXW2j7NldL6WFdVJb0xjonJ0wBAyBjkGgQ4ihoUGkvLhONi6vlLlPFMg7j++Qv5aVS0VZvdTXG0F0jChYq4bAPgSSSwyNIpV7GiTzrypZMKNIJvXLlhHdQzKye0ie6pcjC57bKRkbqqW6vF4ajR/2zZK9MwSSjEDSiGFvVaMUZNbI93NL5Lk7cmRUz5AMS0mX4Rxdm47RF2TIs3fmynfv7pYdmTWyt7ohyXBYc4aAIQACB5wYyCMUFdTJmsVlMntIvoy9OFOGdQ3JsNR0GdEttO9IDbUQxNjemfLqsHxZ9UWZ7C6slTrLP5i0GgJJReCAEQOhQNmuOvn7l0r574QCGX95lgyHELqmywg/KXgJovme4d1CMv6KbJk7YYdsWlHhkpQN9UnFxhozBI5YBBJPDJpH2FgtC2cVydM3ZMuIbhkyrEu6DI9GCF5y6NbkPUAio3pkyuQbc2TRrELJ3FQtVXuMHY5YabWBJw2BhBJDTVWjyw38MK9EZt6dJ6MvyGzyEuIghHDhBbmIhy/MlJn35MmS9y3/kDTpsIaOWAQSQgyN9SK7ttXK2q/L5PWR22Vsb2/Y4Mkj+LyC/Ugg2rXUphBkeGpIHuuT5dqgLfIP9bW2fHHESq8N/IAh0C5iII+wp7hetv5WKe9P2ikTroIQMppWG+L1ElD+1FZIpDn/MKpbSJ68Ols+eHqnbF5VKRWl9cLeCCuGgCGQGATaRQx7iurkqzdKZHr/HBnZo42EgKfQTApjLsqQUT0z9q1UhPMimlcvyD+M7pkh027JlR/nlUhZUV1iELFaDAFDoO3LlY0NIqEN1fJwrwjLj+GUOtw5FyaE5ImrsmXuuO0yplemEDK0GmZAECnpbpVjWv8c2fp7pU2nIWAIJAiBNnsMuO5b11bK0C4sP8agyFFIYcLVWfLd3BJZvqC0dULw15MSkonX5sjmXysSBIlVYwgYAu0jht8qY7PufmXmc2pIhqaky8TrcmTVwlIp3VUvC2YUykPn+DY+hXvWey61aTv15pWxEUNjY6Ps2bNH/v77b9mwYUPg4HxmZqaUlpbGLR01NTWSnZ0dqDNcO3ruzz//lFAoJBUVsfU/7k7F8ABj3bx5s6xfv162b98udXVtD8tyc3Pl999/d0dxcbGAd7RC27T722+/SX5+vjS0I1mUk5Pj6vnjjz86FE/w++uvv1xftm7dGm34B+21jiEGwocu6TKpX46sXFQqe6sbHTHMGrhNhpx3YImhvr5e1q1bJ7fddptce+217rj66qvlqquucsf1118vDz30kLz55pvyzz//SG1tbcyTh1LNmDFDrrzySqFOPWjnuuuuc8c111zTcp7rl19+uYwdO1Y2bdoUczuJvnHt2rUycOBAueKKK+T9999vEylqn+bMmeNwvOyyy2TFihWt4sdcgHmvXr1c21VVVVpV3K+vv/660C54o5CtkZK3AQhq586dcT3jfd77HsODfF188cXyyCOPeC8dMu+TTwyQQkq6TLo+R1YvLnOkQL4i79+9Mu7SrLg2Qbk8RJweA2yOwJ5//vnSqVMnOeecc+Tcc8+V8847z71y7owzzpAuXbrI4MGDBaWJ1Yrl5eXJ448/LmeffbZ07ty55aANDs7z6r125plnyr333uusZkdJzc8//+wU6rTTTpNXXnlFsPRtLS+88IIb31lnnSVLliwRvKhoZdWqVXLBBRfIKaecIm+88YZUVrY9VzR9+nShXeYSzy/WecO6z549W1577bVW+xttLHqtrKxMLrnkEjn11FPd3Or5Q+k1ucTQHD48dX2OrFpcJjXVTW4m34X44/ty912JuPMVbSQGhBEBuvHGG2XmzJnueO655+SJJ56Q/v37O+GGMEaMGCHbtm2LaU53794tixcvFgT02WeflWeeecYdDzzwgHTr1s21d8cdd8iUKVNark2dOlU++OADwQXvqAIx4OWcfvrp7SaGWbNmtRDsjz/+2KqiEUY9//zz8tRTT8ny5ctbvT8aRmAOsaekpDgPLFZieOyxx5wxGDRokLTHY9G+VVdXy0svveTGNG/ePD19SL0mlRj4evXUm3Nk9f/2kQJoVZc3yOezimTQO
WnxJx/bQAy//vqrs1IIEESA61deXu4OrCUexX333eesT58+feSzzz6LycXEdUUoqM97fPzxx9K3b1/p2rWrU7yioqL9riOMhDhaeI9nE6tg63PRXqmPI1yBGAilIAasZklJiRsv98fjjlM3xADh4h3FQgy0AVZY2WjeBdf27t0brvst5yAG2gXnLVu2uL7H8hyeId4Sr4kgBjqEPDGm9nhALQPzvEE2wCHeefFUEdPbpBLDY30yZfmnpS2egvawuqJBFs8ulqFdQjKUL1l5k4utvW8HMSBAWCqvUtInBHXhwoXO8mDpp02b5rwGLBqkkZ6erl1veWWy0tLSBCUj+VZYWNhy7fPPP3euJUREHBxO+OgDzxBzf/PNN7Jo0SL57rvvXCITRfUWhG3jxo3y/fffu7YyMjJk9erV8sMPP7i8COREQSlIpNIn+oA3Q/9JkHpzJ0oMhFAvv/yya3PZsmWuD/QFtzxWAY+XGCBJcKXv9NU/F+RtCDfAY/78+Y5sGC8YQzwrV65sGYsSQ2pqqrsOlswjzxHW8JwqFN7dTz/9JEuXLnX5AEK6W265Rb799lt3LwlRf19Q9F9++cU9w3NgDpb+ArbgzJhIhIYryBDXqYd+EM6ARaSCDDCeL774Qj788EP5+uuvnRySNKYeMKR/iSpJIwb2JjzeN1N++rg08EUovjW5La1G5jxa0LRaEQ85tIMYUNRwxIBAIHA9e/Z01oe8AZ9vvvlmF4vzDOThLZDChAkTXMJpzJgxLsuv11FKPAbawyL7lQxSYXWCxOUNN9wg3bt3dzkO2r/99tudl0EyTQWV0IYQ5MILL5Q777zThTuEAuRFCGNIotE/lHrUqFFy6aWXurZpn/f0DzLQfigxEJ+TeCX0IXHG/YRcd911l3z66aeCMrVW4iUGViNIzDLW9957r6VPtAP5TZ48WUjYMjbyM/Rn5MiRDhf6yPi1X15iwBPs16+fC2sICbl32LBhTrnwxFBEzukck/dhvOSeIBZCSCVYHXNWVpYLuS666CLhIHkMvii21xsDe2SFurkerpBPoR3q6d27tyMl5h9y9BdWW8jdgBN9BAeeffDBB4XwR/uCDCaqJI0Y1AsgwfjJs4Wya9v+SSnIobigVr55u9h912JISoyeQzuIQT0GP5iEEySjmAAEhfckFp988knnbpO9x1prQbHJEzBBPXr0cCsauJJaohEDVh3PAAFHgJl4svQDBgxwgodbzjkEAHJCqBEUyIp4GoHGq8Ha3X333U65yFeQRIQEcK3J+KPctIGwQgAQ1SeffOIIRImB8dIe93MvpMR4qIOxQQ7ecen4vK/xEoMSsD/5SJ9QLjCBFMjygwHEyTnGzfhRTi8xMAa9ritM1MN5xsHzeA4cKP/999/viJhrEAVtQIyMg7nBw0AeduzY4aw/JEV4CSkTeik2eCVasNxgT/KRkFQLcgJpUxc5CGRLQ0zqot+Mk7yLFpSdvBTXmDf6P2TIEDc3zBXj4mCeCJ8SVZJLDGyESg253Y0fTimUwtwa2W+Vm99wKGmQpfNKZGwfdkDGQA7tIAYUbvjw4W69GdcU6wX7M2lYYCYCq8PKBJOKO4rFglAmTZrUEhIwkSxLcT/Kyf3qsjJR0YgBT2D06NFOyBFMEpbkQLBouPN4IUw6gvH000+7kAbF19UPlMb7DHsB6CdLdrjHCNJHH33kVj1wa7HKjIkxEFOjILih5BiUMFhy5F4OFASLRgz+8MMPt5okjZcYCBNQMurHirKfA8XCuiPw9JM616xZ48Ik7h8/frw7r8SgoRYeA89wnpUeyIUlZ+aDsUImKCAEh1X/999/XciBEoMVc8f+EtxzyJc5JCxAHrD84I9BIIT56quvZNy4ca5OcIOEdDVHiYG2IB4t1Es9ENI777zj6iIMIDQgCU49hHO0R6F9wljOMSa8IMIXCICwAg9DPQgI/9AlBvIFXUNuU9TYXpmy6MWiwLcjWaHYsKxCxvaNfWs0vxcZ6wYnXD5NPiIoKOOtt97qmJpXlAaQmVSsJEqim49w8xBYJgoLjavL5H355ZdOsWBwVjhUQFQgIhED1p+4Hw+EvuCRoKgII2EDFguCQJAQmptuuskpsRID51Bo7uEZxsZzKAheBEqFJ0PsqtdYr2f8tIuQ01f1GBgzeyoYp95Pf/BeEEysWWvuanuJgfAGoYeMGB/EzXgVE17pH33BWvs9Bs7RVzDXpC7P4JXhfmPFmSOwBSu8DTBkTrHEeESc18QvBoG2uM5qFffzLOcJLSATSAXiRh4okYgBo8OY6AOrFdTDQT8hCPpHXXh3FOaGfRnMC3IJWSkOzA8hJaEfdR76xAA5NHsO744rEPYwaKmvFflnTaW8MDBPRup9Byj5iJXCujARJ598sjtwZ5k0FApXGqXC7dMCQeB+IwQ8/9Zbb7mYmLgexSYGRNlUqPS5SMRADEsdCATuKZYMofQWhIZ7UHSUhWQaioK1QlgRTLwELQg2ysTYcHuVvPQ6RKaCjXDz2UsMJB+9xEb7WDnGh1VrbSdfIogBvBgv48PDoQ/eAkbkeSIRA1aUcXvnAWuN646SYYUZOwXSxDtUYvC35SUGjIE3lKJ+rDveDsqJN0OJRgzImxKDu7n5D+0yX/QPnCmQMB4h55hv73i4zjzieXD9sCGGUT1C8u2cfRtpIIUtayrl+QF5TV+/jvX7F+0IJQCd+BNhfvHFF91BcpDlSbLSWFfvZKBE7FAk9kYoiUe5j89YKSYPVveXSMSAMNAegow7S8bZX7AMKAfhBB4MxOQlBnYsFhQUtDxGn3GbIQbCiFjcSy8xvPrqq265UivEghPqJJsYIEsEHsvqV1b6RijHGMN5DOAJEXjnjs/qCULkSgy7du1qlRiYXxQaj8FLDPTj3XffbSEG8iWUSMQAzhAIROLf30BfIR6uQQzIGiSMEVJicJV7/uA90P5hRQyP9s2UTSuadrk5UlhbKc/fm+d+KDamb1eqJ9EOYkCAiFexGnqgWFhyFJLJ8RfiWZQZYiBxxOYYBA5rzn4FWNxfIhEDbRBf4n2gCFhr2vcWvAGsI0qAZWOlQUMJrBwxrJcYEHjyD5Ae+RDyFP4+kUhFUSA5vCDNMSBgkYgBAY3XY6BehDda8ecYICLyPZAgSgTJeb026mKJECWNRAx4fIwtFmJgiRhcUXxCCT8JgWc0Ypg7d25CiAEviDEpMfAZWcOTpG8QIPPmLXh2zP/hE0p0C8nkG3OlKJ9fYBLZAinc1wZSaA432ppjgBhQunAE4J0A73smjCQlwoQVhRQgCWI91vvDlUjEwL0kk1TwsPCQC8IKaSAIfGeDHATCMXToUOcBRCMG6iRDzsoCAsMzKBpkB0EQv0McxM0ks/Bw/B6DN5RQjyFeYkBp8W5YrycW9x7kCGgDxQ1HDFhcPCG8MJZuSa6SCIUQwIv+E2pECiXiIQaIGKWjLhK2uO+QmZJKa8SQKI8BufJ6DNoHPCNCXWSNZWbGz/yTJCUkwqDQ90M/lEht+gczr43Ml9rqxqbwoa2k0A5iUCsNMTAp8RQ23ZChRllR
ANx8Mup+N1PrVGJgEv37GEhm8Sx1oAgIKQLA+j1JR5JPeAaEGuQgsGgaStC+32OgTZJSrPXTN8jvnnvucQTIWEls4U0Q5xIf4ykpMVAfHoOfGFiNgBhIfsaSY6B+BJn7Wfpjvd170B/Ii7FADHgHeCvgABFB1Hg6+n0DyAHypC6Uh7rBSvHyL1eCcziPAQ8KLL2hBAQMVlhq+g2RkgSGTJELiIHlQkg2UihB3+mLP5RQj07lAJy5j/v9oYQSA9fwzNTTAm8I68QTT3RLyYyfEJY+MbfMmRqoWMJG7Utrr8ldrmxW5FHnZ8iXrxVL2rpKmXFvbvzhg4YRbSQGJggwmQRCgXiJAcuL4BJKIDCsZrDTLZLngUJj4WiPfIaucjA5PIPVJsdBOEICFCGlXl5RYLLmuKysgVOwnCQEjznmGGf5ISpvYTzkPkhCMk4sDgKEoCJgWBiUgXCCe9l9h4dBfXxvAbLQQl9x50866STnubD8F61g3ekz99NWuOOoo45yxMeSIeEGpHP00Ue7UEqxgTRY7UEBUCatj7FAlqwm+UMJyPr44493OLNSo1af/pKMBAvamThxYkuOgWtYYbw1sKf+4447zi3xopwQA0uRxx57rCNtP/lDZtTJdcZCwQuB6DmPZ6YFnLmP8yQNvYV5wChwDQ9RiYHzEI56UOCpWECoyMZB5zFsWV0pAzulybCUkPtXc7qJKeor+xguzpT5U3fKzAF5MjSlafky6jNeIvC+d1/fDsm4K7Jk44p9G4q8gPvfAzTMinCQLFywYMF+AuS/P9JnFASBQThZHvQqk/8ZQg88APIZrFogbN4COZAEQ3BQLCwjqyIoJISBVVWryHNYdPpN/xEwf16CexgnBEJowhIkqxdYap6BqHDnVfgYC+3iGbCs5xV++op1e/TRR93uO+8KiHcM+p7x4f6CL5Y33EEoQK6E8IZQA5eYcyznQbpKsFhzQhCSsuz8o4+47pAeG4dQCJRJsWFvAeERnhFek9ZD3+i3tsN91K0FrAgD3377bTdH4MC+Cc6DEfsFmAu+ZEefvYVlRiw4+QklTe7hGZ0fvR+y4j7uZ669BRJjfFzDo6FtCmPgGuQPtsgD44CQIAzagPAJJVrz5rzttfa+zR4DO5N2ba9r2cb84NlpMf9oCx7DuMuyZERqRszPBIgjNSSDzm760tX8aeykjJ7k8gLBxAE0wqJC5b3e2nueR7jwAlgCY9+61zr5n8c91vawkl6B1Xs5hxBCECgtm294JQHlFWLu5zNkQFJOY3Wtx/tKn2ibXAX1ET8zZr9wY52pi/uI771joV+0gaJxj5KJtx3ve57HsyEhyv2RDu6jbohHdwNCSHqO3AoKjgcDdlxjzNxPP3G38YIIMdTLoE5tz99PPmu/whEpY+Z5+s1YwY1Cf2gfggI7PnsLdTFPkDA46jPkiajLG5Zxnfu4P1wfaIdrtKOFkItNbhgWFJ86kFnmkHETGuJdkqTk+USVthODiLCNuTCvTlYuLpMZd+c5RR0U4y8wuZWHWJckvV5Ct5A81DldBndOk5kDcmXlojIpzmdjT3yQMMF6xPdkUwyPBYWpEVCEprWibfkFy/+c975o9+p9/ufDfdZ79TXaPfFei3a/thfu1fucXtdzKCnWE9efRCJYYy2xuIRw7DIlj4DbT2jmXXXx16V18hrtmt4X7h49x2u4otf918LdH+lentVr3uf48hghJZvnCCcwSIRFhD+ETuQZkEM8Gj8Z+vsTz+d2EYNrqFGkrlokd3O1fPlmsfsJ+QfOTJMh/BZk9zb+FqSPCNRb4NuXD3ZKk4nXZss3b5VIfjr/+BYFj2fI7bsX8PmGHclLYmNc5UROSPt6d/g8TdadDT+aEyGfw+4/fvEKwjjhhBPcZ7ajexXp8EGgaSR4HnhF5CbIUeEZgAO5BeQPfMi3QBSJLO0nhube8FN9lbsbJW1dlXwweaeM7pkpEITzDNpJEOQwCFXGXJQpH0/dKaE/qtz/kvDumkwkKNHqwtVnSy0JMSaH7yVYSTwChAysDBBDk5RFKUi64SUQwrFio3mAxLd+8NQI6RFCkHdglUYTyOCAt4AXQb4CvBJZEkYM2qmGOpGSgjr5c+keeXXYdhlybpNSx/3LTHgNzXkEfjT29VH58ufSctldyA+YaGvJfyW2YymMRBmCS1xq5cAgQAKOOJ1kI8lKkqZ8x0PX8o8UTw1ywHNgFYnkMDgQYvCdF3IWBwKHhBODikhdjUhBVq38/MlumdY/VwZ1SnO5gVjDi8HnprtfjJ5+e46sWFgqu/JqpY78YhLDBh2L/xWB1cN/zT4nFgGUgpwDFhFS5uDz4Rw+REKQcZN4VhyQwQOFwwEjBje4BpGqikbJ3bJX/vdKkVuJILzAA4hEEC5s6JQm4y/Pcv/lyuURKpKbR4g0MXbeEDhSEDiwxNCMIq5/eUm9/LuuSuY9uUNG9AjJA532X94cTh6hU5qMPj/D/U/K0PrmPMJB4CEcKcJg4zQEFIGkEIM2Rv6heHudrF+yR14elC+DO6e7JU7+yQz/T2L24G2yflm5lHZwHkH7a6+GwJGKQFKJQUGu3SuyI6dWfpq/W6bcnC1Tb8lxPxK7a3ut8EMtB0MeQftqr4bAkYhAhxCDA7pBpLqyUXZk18r2jBrZW2GEcCQKoI354ESg44ihGQ82JyVzg9LBOQ3WK0Pg4EKgw4nh4ILDemMIGAIgYMRgcmAIGAIBBIwYApDYCUPAEDBiMBkwBAyBAAJGDAFI7IQhYAgYMZgMGAKGQAABI4YAJHbCEDAEjBhMBgwBQyCAgBFDABI7YQgYAkYMJgOGgCEQQMCIIQCJnTAEDAEjBpMBQ8AQCCBgxBCAxE4YAoaAEYPJgCFgCAQQMGIIQGInDAFDwIjBZMAQMAQCCBgxBCCxE4aAIWDEYDJgCBgCAQSMGAKQ2AlDwBD4P9CuROTFaWXrAAAAAElFTkSuQmCC)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i7XbLCXGkll9" + }, + "source": [ + "# The Boring Model\n", + "Replicate a bug you experience, using this model.\n", + "\n", + "[Remember! 
We're always available for support on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2LODD6w9ixlT" + }, + "source": [ + "---\n", + "## Setup env" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zK7-Gg69kMnG" + }, + "source": [ + "%%capture\n", + "! pip install pytorch-lightning --upgrade" + ], + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WvuSN5jEbY8P" + }, + "source": [ + "---\n", + "## Deps" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "w4_TYnt_keJi" + }, + "source": [ + "import os\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "from pytorch_lightning import LightningModule, Trainer" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XrJDukwPtUnS" + }, + "source": [ + "---\n", + "## Data\n", + "Random data is best for debugging. If you need special tensor shapes, batch compositions, or dataloaders, modify as needed." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "hvgTiaZpkvwS" + }, + "source": [ + "class RandomDataset(Dataset):\n", + " def __init__(self, size, num_samples):\n", + " self.len = num_samples\n", + " self.data = torch.randn(num_samples, size)\n", + "\n", + " def __getitem__(self, index):\n", + " return self.data[index]\n", + "\n", + " def __len__(self):\n", + " return self.len\n" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "sxVlWjGhl02D" + }, + "source": [ + "num_samples = 10000" + ], + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "V7ELesz1kVQo" + }, + "source": [ + "class BoringModel(LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.layer = torch.nn.Linear(32, 2)\n", + "\n", + " def forward(self, x):\n", + " return self.layer(x)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " loss = self(batch).sum()\n", + " self.log(\"train_loss\", loss)\n", + " return {\"loss\": loss}\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " loss = self(batch).sum()\n", + " self.log(\"valid_loss\", loss)\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " loss = self(batch).sum()\n", + " self.log(\"test_loss\", loss)\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.SGD(self.layer.parameters(), lr=0.1)" + ], + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ubvW3LGSupmt" + }, + "source": [ + "---\n", + "## Define the test\n", + "NOTE: in Colab, set progress_bar_refresh_rate high or the screen will freeze because of the rapid tqdm update speed."
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4Dk6Ykv8lI7X" + }, + "source": [ + "def run():\n", + " train_data = DataLoader(RandomDataset(32, 64), batch_size=2)\n", + " val_data = DataLoader(RandomDataset(32, 64), batch_size=2)\n", + " test_data = DataLoader(RandomDataset(32, 64), batch_size=2)\n", + "\n", + " model = BoringModel()\n", + " trainer = Trainer(\n", + " default_root_dir=os.getcwd(),\n", + " limit_train_batches=1,\n", + " limit_val_batches=1,\n", + " num_sanity_val_steps=0,\n", + " max_epochs=1,\n", + " enable_model_summary=False,\n", + " )\n", + " trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)\n", + " trainer.test(model, dataloaders=test_data)" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4dPfTZVgmgxz" + }, + "source": [ + "---\n", + "## Run Test" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AAtq1hwSmjKe", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 272, + "referenced_widgets": [ + "d79c1628eded487a974da18a2ea1f98b", + "02695b143b764932ba8d0c08a872987e", + "28eb6a3218f64f26abcdff756ffda3ad", + "02cfffd590014c3cbc44ab06c69f9181", + "0d7c50e36cb84f01a57a9d7d8b913393", + "6ba2782883ae424dbfc8868224d95da9", + "baa4aacd0da64cf291fb31c000724573", + "7dad3d2feced492a999fb6c91186be50", + "ea702a091eb642f7bdda81aa55db8c26", + "4802a47c6dfb439c83d8b860dce42006", + "68c87e6a7fcf4e4eab98a941c7c3e867", + "e6cbe583c2e14986b4faeb27e31f73e1", + "672dd78899f944cea7e57f388f3ecb31", + "cd61dda59d104e0a8a8aa9bfc1e55c24", + "1cd72d82332941a6929f88fad5173096", + "92a38638060c4ed5b6d44a2078667e53", + "bdc9b06391ee47478efd58cc91ca87ac", + "ee80657d62c6452d9e9ac199157cdf2a", + "eb16b87bcb8d4ca6a83e8b44ea2d1311", + "2a6327dd568241e3acbb6aec1926bd80", + "a45aba8517e14654850453159780b54a", + "7fb167222e7143b789b7f40af7cb39dd", + "abe1c0c4dac94e0e9b894bb69c3ec450", + "23763e19d40d4020b3342a47366e2e19", + "0b7b7da6a6134f0fb26a05adc062ee6f", + "9941635d9d694ba7bce0c7a14c500e5e", + "c7f1407ba92f4dc6ba34bd9cf73fea69", + "86f2e0a558cc419e84ed9192ccd3d1b6", + "141a9c35ade14d9e8645b2c108ab4d66", + "833bb79bb1214a3a88795f41b9375690", + "51b32955ad544803b1d78f07bc685569", + "a6b2764a5fa9444a9d77e8d74c67ef47", + "3f0c08c03e284ebb905dae8aca72fffc" + ] + }, + "outputId": "59e8bcf2-a944-46fc-a771-e7cbbbe4727d" + }, + "source": [ + "run()" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "GPU available: True, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py:1567: UserWarning: GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.\n", + " \"GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.\"\n", + "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/data_loading.py:395: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). 
Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", + " f\"The number of training samples ({self.num_training_batches}) is smaller than the logging interval\"\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d79c1628eded487a974da18a2ea1f98b", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e6cbe583c2e14986b4faeb27e31f73e1", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "Validating: 0it [00:00, ?it/s]" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "abe1c0c4dac94e0e9b894bb69c3ec450", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "Testing: 0it [00:00, ?it/s]" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--------------------------------------------------------------------------------\n", + "DATALOADER:0 TEST RESULTS\n", + "{'test_loss': -1.676544427871704}\n", + "--------------------------------------------------------------------------------\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Flyi--SpvsJN" + }, + "source": [ + "---\n", + "## Environment\n", + "Run this to get the environment details" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0-yvGFRoaDSi" + }, + "source": [ + "%%capture\n", + "! wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/requirements/collect_env_details.py" + ], + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "quj4LUDgmFvj", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "bb7a5f74-d52c-4927-b12a-49589aed7dcb" + }, + "source": [ + "! 
python collect_env_details.py" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "* CUDA:\n", + "\t- GPU:\n", + "\t\t- Tesla K80\n", + "\t- available: True\n", + "\t- version: 11.1\n", + "* Packages:\n", + "\t- numpy: 1.19.5\n", + "\t- pyTorch_debug: False\n", + "\t- pyTorch_version: 1.10.0+cu111\n", + "\t- pytorch-lightning: 1.5.1\n", + "\t- tqdm: 4.62.3\n", + "* System:\n", + "\t- OS: Linux\n", + "\t- architecture:\n", + "\t\t- 64bit\n", + "\t\t- \n", + "\t- processor: x86_64\n", + "\t- python: 3.7.12\n", + "\t- version: #1 SMP Sat Jun 5 09:50:34 PDT 2021\n" + ] + } + ] + } + ] +} diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py similarity index 100% rename from pl_examples/bug_report_model.py rename to pl_examples/bug_report/bug_report_model.py diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 6bd7db1aeff8d..3c8912e145305 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -22,12 +22,11 @@ import torch from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader -from pl_examples.bug_report_model import RandomDataset from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.loops import Loop, TrainingBatchLoop from pytorch_lightning.trainer.progress import BaseProgress -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset from tests.helpers.runif import RunIf From ec27313be242b4131354490e9dbed365ed7b9be9 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 19 Nov 2021 22:18:26 +0530 Subject: [PATCH 02/59] Fix batch size extraction when set by the user in `LightningModule.log` (#10408) Co-authored-by: Carlos Mocholi --- CHANGELOG.md | 3 ++ .../loops/epoch/training_epoch_loop.py | 6 +-- .../logger_connector/logger_connector.py | 23 ++++----- .../connectors/logger_connector/result.py | 51 +++++++++---------- pytorch_lightning/utilities/data.py | 5 +- tests/deprecated_api/__init__.py | 28 +++++++--- tests/loops/test_loop_state_dict.py | 13 +++-- .../logging_/test_train_loop_logging.py | 38 ++++++++++++-- tests/utilities/test_data.py | 7 ++- 9 files changed, 112 insertions(+), 62 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65cf84ad092c6..9e52442c7b356 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -158,6 +158,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610)) +- Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) + + - diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 21d89a8be8b52..8ddca3ad505e8 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -161,9 +161,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.batch_progress.increment_ready() - # cache the batch size value to avoid extracting it again after the batch loop runs as the value will be - # different if tbptt is enabled - batch_size = self.trainer.logger_connector.on_batch_start(batch_idx, batch) + self.trainer.logger_connector.on_batch_start(batch_idx, batch) if batch is None: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") @@ -194,8 +192,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None: with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, batch_idx) - self.trainer._results.batch_size = batch_size - self.batch_progress.increment_processed() # update non-plateau LR schedulers diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 7a457ebf4fcc1..d970d98c602bc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -214,7 +214,6 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]: def on_train_split_start(self, split_idx: int, split_batch: Any) -> None: self._split_idx = split_idx - self.on_new_batch(split_batch) def update_train_step_metrics(self) -> None: if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization: @@ -257,28 +256,23 @@ def _log_gpus_metrics(self) -> None: Utilities and properties """ - def on_new_batch(self, batch: Any) -> int: - # when the user requests `dataloader_iter`, we can't track the batch_size - # and this is left to user responsibility. 
- if not isinstance(batch, pl.utilities.fetching.StepFuncDataLoaderIter): - assert self.trainer._results is not None - return self.trainer._results.extract_batch_size(batch) - return 1 - def on_epoch_start(self) -> None: self._epoch_end_reached = False - def on_batch_start(self, batch_idx: int, batch: Any) -> int: + def on_batch_start(self, batch_idx: int, batch: Any) -> None: self._batch_idx = batch_idx self._epoch_end_reached = False - return self.on_new_batch(batch) + + assert self.trainer._results is not None + # attach reference to the new batch and remove the cached batch_size + self.trainer._results.batch = batch + self.trainer._results.batch_size = None def epoch_end_reached(self) -> None: self._epoch_end_reached = True self._batch_idx = None self._split_idx = None assert self.trainer._results is not None - self.trainer._results.batch_size = 1 def on_epoch_end(self) -> None: assert self._epoch_end_reached @@ -295,6 +289,11 @@ def on_batch_end(self) -> None: self._callback_metrics.update(metrics["callback"]) self._logged_metrics.update(metrics["log"]) + assert self.trainer._results is not None + # drop the reference to current batch and batch_size + self.trainer._results.batch = None + self.trainer._results.batch_size = None + def should_reset_tensors(self, fx: str) -> bool: is_different_fx = self._current_fx != fx if self._split_idx is None: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index fc2e48bb17133..1b82baf0440c9 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -212,7 +212,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum) - def update(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None: + def update(self, value: _IN_METRIC, batch_size: int) -> None: if self.is_tensor: if not torch.is_floating_point(value): dtype = torch.get_default_dtype() @@ -259,7 +259,7 @@ def reset(self) -> None: self.value.reset() self.has_reset = True - def forward(self, value: _IN_METRIC, batch_size: torch.Tensor) -> None: + def forward(self, value: _IN_METRIC, batch_size: int) -> None: if self.meta.enable_graph: with torch.no_grad(): self.update(value, batch_size) @@ -385,8 +385,9 @@ class ResultCollection(dict): def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None: super().__init__() self.training = training - self._batch_size = torch.tensor(1, device=device) self.device: Optional[Union[str, torch.device]] = device + self.batch: Optional[Any] = None + self.batch_size: Optional[int] = None @property def result_metrics(self) -> List[ResultMetric]: @@ -399,14 +400,23 @@ def append_fn(v: ResultMetric) -> None: apply_to_collection(list(self.values()), ResultMetric, append_fn) return o - @property - def batch_size(self) -> torch.Tensor: - # performance: cache the `batch_size` tensor instead of re-creating it - return self._batch_size + def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int: + # check if we have extracted the batch size already + if batch_size is None: + batch_size = self.batch_size + + if batch_size is not None: + return batch_size - @batch_size.setter - def batch_size(self, value: int) -> None: - self._batch_size = torch.tensor(value, device=self.device) + batch_size = 1 + if self.batch 
is not None and meta.on_epoch and meta.is_mean_reduction: + try: + batch_size = extract_batch_size(self.batch) + self.batch_size = batch_size + except RecursionError: + pass + + return batch_size def log( self, @@ -467,10 +477,8 @@ def log( f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed" ) - if batch_size is not None: - self.batch_size = batch_size - - self.update_metrics(key, value) + batch_size = self._extract_batch_size(batch_size, meta) + self.update_metrics(key, value, batch_size) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: """Create one ResultMetric object per value. @@ -487,10 +495,10 @@ def fn(v: _IN_METRIC) -> ResultMetric: value = ResultMetricCollection(value) self[key] = value - def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None: - def fn(result_metric: ResultMetric, v: ResultMetric) -> None: + def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None: + def fn(result_metric: ResultMetric, v: torch.Tensor) -> None: # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` - result_metric.forward(v.to(self.device), self.batch_size) + result_metric.forward(v.to(self.device), batch_size) result_metric.has_reset = False apply_to_collections(self[key], value, ResultMetric, fn) @@ -584,19 +592,10 @@ def fn(item: ResultMetric) -> None: apply_to_collection(self, ResultMetric, fn) - def extract_batch_size(self, batch: Any) -> int: - try: - batch_size = extract_batch_size(batch) - except RecursionError: - batch_size = 1 - self.batch_size = batch_size # the setter converts it to `Tensor` - return batch_size - def to(self, *args: Any, **kwargs: Any) -> "ResultCollection": """Move all data to the given device.""" self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)) - self._batch_size = self._batch_size.to(*args, **kwargs) if "device" in kwargs: self.device = kwargs["device"] return self diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index aeda637b7d35b..bbe41217f1346 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -36,7 +36,10 @@ def _extract_batch_size(batch: BType) -> Generator[int, None, None]: if isinstance(batch, torch.Tensor): - yield batch.size(0) + if batch.ndim == 0: + yield 1 + else: + yield batch.size(0) elif isinstance(batch, str): yield len(batch) elif isinstance(batch, (Iterable, Mapping)): diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py index 1026981f75307..91c7ef1c1f880 100644 --- a/tests/deprecated_api/__init__.py +++ b/tests/deprecated_api/__init__.py @@ -14,7 +14,7 @@ """Test deprecated functionality which will be removed in vX.Y.Z.""" import sys from contextlib import contextmanager -from typing import Optional +from typing import Optional, Type import pytest @@ -26,14 +26,28 @@ def _soft_unimport_module(str_module): @contextmanager -def no_deprecated_call(match: Optional[str] = None): +def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None): with pytest.warns(None) as record: yield + + if match is None: try: - w = record.pop(DeprecationWarning) - if match is not None and match not in str(w.message): - return + w = record.pop(expected_warning) except AssertionError: - # no DeprecationWarning raised + # no warning raised + return + else: + for w in record.list: + if w.category is 
expected_warning and match in w.message.args[0]: + break + else: return - raise AssertionError(f"`DeprecationWarning` was raised: {w}") + + msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`" + raise AssertionError(f"{msg} was raised: {w}") + + +@contextmanager +def no_deprecated_call(match: Optional[str] = None): + with no_warning_call(expected_warning=DeprecationWarning, match=match): + yield diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 717d625f6c44e..72eeb197e9e57 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -14,7 +14,6 @@ from unittest.mock import Mock import pytest -import torch from pytorch_lightning.loops import FitLoop from pytorch_lightning.trainer.trainer import Trainer @@ -80,14 +79,16 @@ def test_loops_state_dict_structure(): "is_last_batch": False, }, "epoch_loop.val_loop._results": { + "batch": None, + "batch_size": None, "training": False, - "_batch_size": torch.tensor(1), "device": None, "items": {}, }, "epoch_loop._results": { + "batch": None, + "batch_size": None, "training": True, - "_batch_size": torch.tensor(1), "device": None, "items": {}, }, @@ -106,8 +107,9 @@ def test_loops_state_dict_structure(): "is_last_batch": False, }, "_results": { + "batch": None, + "batch_size": None, "training": False, - "_batch_size": torch.tensor(1), "device": None, "items": {}, }, @@ -122,8 +124,9 @@ def test_loops_state_dict_structure(): "is_last_batch": False, }, "_results": { + "batch": None, + "batch_size": None, "training": False, - "_batch_size": torch.tensor(1), "device": None, "items": {}, }, diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 22a1a2c90d756..0ec61358d9408 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -27,6 +27,7 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.deprecated_api import no_warning_call from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset from tests.helpers.runif import RunIf @@ -715,19 +716,15 @@ def on_validation_epoch_end(self): assert all(v == 3 for v in self.trainer.callback_metrics.values()) def on_train_batch_start(self, batch, batch_idx): - assert self.trainer._results.batch_size == 2 self.log("on_train_batch_start", 1.0, reduce_fx="sum") def on_train_batch_end(self, outputs, batch, batch_idx): - assert self.trainer._results.batch_size == 2 self.log("on_train_batch_end", 1.0, reduce_fx="sum") def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): - assert self.trainer._results.batch_size == 2 self.log("on_validation_batch_start", 1.0, reduce_fx="sum") def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - assert self.trainer._results.batch_size == 2 self.log("on_validation_batch_end", 1.0, reduce_fx="sum") def training_epoch_end(self, *_) -> None: @@ -749,3 +746,36 @@ def validation_epoch_end(self, *_) -> None: train_data = DataLoader(RandomDataset(32, 64), batch_size=2) val_data = DataLoader(RandomDataset(32, 64), batch_size=2) trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + + +def test_no_batch_size_extraction_with_specifying_explictly(tmpdir): + batch_size = 
BoringModel().train_dataloader().batch_size + 1 + fast_dev_run = 2 + log_val = 7 + + class CustomBoringModel(BoringModel): + def on_before_batch_transfer(self, batch, *args, **kwargs): + # This is an ambiguous batch which have multiple potential batch sizes + if self.trainer.training: + batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch} + return batch + + def training_step(self, batch, batch_idx): + self.log("step_log_val", log_val, on_epoch=False) + self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True) + self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum") + return super().training_step(batch["batch2"], batch_idx) + + def on_train_epoch_end(self, *args, **kwargs): + results = self.trainer._results + assert results["training_step.step_log_val"].value == log_val + assert results["training_step.step_log_val"].cumulated_batch_size == 0 + assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run + assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run + assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run + + model = CustomBoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run) + + with no_warning_call(match="Trying to infer the `batch_size`"): + trainer.fit(model) diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index acbe645515f55..f4c61cda64f5d 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -12,6 +12,7 @@ warning_cache, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.deprecated_api import no_warning_call from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset @@ -19,9 +20,8 @@ def test_extract_batch_size(): """Tests the behavior of extracting the batch size.""" def _check_warning_not_raised(data, expected): - with pytest.warns(None) as record: + with no_warning_call(match="Trying to infer the `batch_size`"): assert extract_batch_size(data) == expected - assert len(record) == 0 def _check_warning_raised(data, expected): with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."): @@ -43,6 +43,9 @@ def _check_warning_raised(data, expected): batch = {"test": [{"test": [torch.zeros(11, 10)]}]} _check_warning_not_raised(batch, 11) + batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])} + _check_warning_raised(batch, 1) + batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]} _check_warning_raised(batch, 11) From 17a8290ca7ac940cdd67a389f6fae0eb979bfebd Mon Sep 17 00:00:00 2001 From: Aki Nitta Date: Sat, 20 Nov 2021 01:49:07 +0900 Subject: [PATCH 03/59] Use new GitHub labels (#10552) --- .github/stale.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/stale.yml b/.github/stale.yml index 1ac5e7448c9ff..51b57c079879d 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -8,8 +8,8 @@ issues: daysUntilClose: 7 # Issues with these labels will never be considered stale exemptLabels: - - p0 - - p1 + - "priority: 0" + - "priority: 1" # Comment to post when marking an issue as stale. Set to `false` to disable markComment: > This issue has been automatically marked as stale because it hasn't had any recent activity. 
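The batch-size handling introduced above (in `ResultCollection._extract_batch_size`) means `self.log` only falls back to inferring a batch size from the current batch when the value is accumulated on epoch end with mean reduction; otherwise an explicitly passed `batch_size` (or a default of 1) is used. A minimal sketch of both paths, assuming the `extract_batch_size` import path and the `BoringModel` test helper used in the tests above:

import torch
from pytorch_lightning.utilities.data import extract_batch_size
from tests.helpers.boring_model import BoringModel

# extract_batch_size walks nested containers and returns the first dimension
# of the first tensor it finds; 0-dim tensors now count as batch size 1.
assert extract_batch_size(torch.zeros(11, 10)) == 11
assert extract_batch_size({"test": [torch.zeros(8, 3)]}) == 8
assert extract_batch_size(torch.tensor(1.0)) == 1

class ExplicitBatchSizeModel(BoringModel):
    def training_step(self, batch, batch_idx):
        out = super().training_step(batch, batch_idx)
        # Passing batch_size explicitly skips the inference (and the
        # "Trying to infer the `batch_size`" warning) even for on_epoch
        # metrics computed over ambiguous batch structures.
        self.log("train_loss", out["loss"], on_step=False, on_epoch=True, batch_size=batch.size(0))
        return out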
From ff8ac6e2e11adcca9cc773f6f038ff8cf9d37440 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 19 Nov 2021 22:22:24 +0530 Subject: [PATCH 04/59] Make `_get_nvidia_gpu_stats` public (#10406) --- pytorch_lightning/accelerators/gpu.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index b84e53c0426fe..62af5f27dcc1c 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -73,7 +73,7 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: """ if _TORCH_GREATER_EQUAL_1_8: return torch.cuda.memory_stats(device) - return _get_nvidia_gpu_stats(device) + return get_nvidia_gpu_stats(device) def teardown(self) -> None: super().teardown() @@ -85,7 +85,7 @@ def auto_device_count() -> int: return torch.cuda.device_count() -def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: +def get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. Args: @@ -98,6 +98,10 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: FileNotFoundError: If nvidia-smi installation not found """ + nvidia_smi_path = shutil.which("nvidia-smi") + if nvidia_smi_path is None: + raise FileNotFoundError("nvidia-smi: command not found") + gpu_stat_metrics = [ ("utilization.gpu", "%"), ("memory.used", "MB"), @@ -111,9 +115,6 @@ def _get_nvidia_gpu_stats(device: torch.device) -> Dict[str, float]: gpu_query = ",".join(gpu_stat_keys) gpu_id = _get_gpu_id(device.index) - nvidia_smi_path = shutil.which("nvidia-smi") - if nvidia_smi_path is None: - raise FileNotFoundError("nvidia-smi: command not found") result = subprocess.run( [nvidia_smi_path, f"--query-gpu={gpu_query}", "--format=csv,nounits,noheader", f"--id={gpu_id}"], encoding="utf-8", @@ -130,10 +131,7 @@ def _to_float(x: str) -> float: s = result.stdout.strip() stats = [_to_float(x) for x in s.split(", ")] - - gpu_stats = {} - for i, (x, unit) in enumerate(gpu_stat_metrics): - gpu_stats[f"{x} ({unit})"] = stats[i] + gpu_stats = {f"{x} ({unit})": stat for (x, unit), stat in zip(gpu_stat_metrics, stats)} return gpu_stats From 5d748e560b72ba3d3c93de683548a67a8426e29c Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Fri, 19 Nov 2021 18:03:14 +0100 Subject: [PATCH 05/59] LightningCLI changes for jsonargparse>=4.0.0 (#10426) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: thomas chaton --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/cli.py | 27 +++++++++++++------------- pytorch_lightning/utilities/imports.py | 2 +- requirements/extra.txt | 2 +- tests/utilities/test_cli.py | 4 ++-- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e52442c7b356..0f5f3644b6846 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426)) + + - Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 6ed485257feaa..b08ad7265ca60 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -14,7 +14,6 @@ import inspect import os import sys -from argparse import Namespace from types import MethodType, ModuleType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from unittest import mock @@ -32,13 +31,12 @@ from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple if _JSONARGPARSE_AVAILABLE: - from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, set_config_read_mode - from jsonargparse.actions import _ActionSubCommands + from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode from jsonargparse.optionals import import_docstring_parse set_config_read_mode(fsspec_enabled=True) else: - ArgumentParser = object + ArgumentParser = Namespace = object class _Registry(dict): @@ -100,7 +98,7 @@ class LightningArgumentParser(ArgumentParser): # use class attribute because `parse_args` is only called on the main parser _choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {} - def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize argument parser that supports configuration file input. For full details of accepted arguments see `ArgumentParser.__init__ @@ -109,9 +107,9 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non if not _JSONARGPARSE_AVAILABLE: raise ModuleNotFoundError( "`jsonargparse` is not installed but it is required for the CLI." - " Install it with `pip install jsonargparse[signatures]`." + " Install it with `pip install -U jsonargparse[signatures]`." ) - super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs) + super().__init__(*args, **kwargs) self.add_argument( "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format." 
) @@ -363,7 +361,7 @@ class SaveConfigCallback(Callback): def __init__( self, parser: LightningArgumentParser, - config: Union[Namespace, Dict[str, Any]], + config: Namespace, config_filename: str, overwrite: bool = False, multifile: bool = False, @@ -671,8 +669,7 @@ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser: if subcommand is None: return self.parser # return the subcommand parser for the subcommand passed - action_subcommands = [a for a in self.parser._actions if isinstance(a, _ActionSubCommands)] - action_subcommand = action_subcommands[0] + action_subcommand = self.parser._subcommands_action return action_subcommand._name_parser_map[subcommand] def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: @@ -772,12 +769,16 @@ def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]: return fn_kwargs -def _global_add_class_path(class_type: Type, init_args: Dict[str, Any] = None) -> Dict[str, Any]: +def _global_add_class_path( + class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None +) -> Dict[str, Any]: + if isinstance(init_args, Namespace): + init_args = init_args.as_dict() return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}} -def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]: - def add_class_path(init_args: Dict[str, Any]) -> Dict[str, Any]: +def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]: + def add_class_path(init_args: Namespace) -> Dict[str, Any]: return _global_add_class_path(class_type, init_args) return add_class_path diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 5db24fe0f5cff..aa6349b5d677a 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -85,7 +85,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: _HOROVOD_AVAILABLE = _module_available("horovod.torch") _HYDRA_AVAILABLE = _module_available("hydra") _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental") -_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") +_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse") and _compare_version("jsonargparse", operator.ge, "4.0.0") _KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available() _NEPTUNE_AVAILABLE = _module_available("neptune") _NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0") diff --git a/requirements/extra.txt b/requirements/extra.txt index 4aea9dad9cfad..6abf3089b8506 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -5,6 +5,6 @@ horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already insta torchtext>=0.8.* omegaconf>=2.0.5 hydra-core>=1.0.5 -jsonargparse[signatures]>=3.19.3 +jsonargparse[signatures]>=4.0.0 gcsfs>=2021.5.0 rich>=10.2.2 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 5f824d1beed0b..1d6146f16e3e2 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -348,7 +348,7 @@ def test_lightning_cli_args(tmpdir): loaded_config = yaml.safe_load(f.read()) loaded_config = loaded_config["fit"] - cli_config = cli.config["fit"] + cli_config = cli.config["fit"].as_dict() assert cli_config["seed_everything"] == 1234 assert "model" not in loaded_config and "model" not in cli_config # no arguments to include @@ -404,7 +404,7 @@ 
def test_lightning_cli_config_and_subclass_mode(tmpdir): loaded_config = yaml.safe_load(f.read()) loaded_config = loaded_config["fit"] - cli_config = cli.config["fit"] + cli_config = cli.config["fit"].as_dict() assert loaded_config["model"] == cli_config["model"] assert loaded_config["data"] == cli_config["data"] From 5fe0dac119ca1b1e5c3d0971e6263d1fbf0c586c Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Sat, 20 Nov 2021 06:26:50 +1300 Subject: [PATCH 06/59] Fix misleading ModelCheckpoint documentation on every_n_epochs parameter (#10421) --- pytorch_lightning/callbacks/model_checkpoint.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index e195072c9718f..33f872f3a9f9b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -114,8 +114,14 @@ class ModelCheckpoint(Callback): guaranteed to execute at the exact time specified, but should be close. This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``. every_n_epochs: Number of epochs between checkpoints. - If ``every_n_epochs == None or every_n_epochs == 0``, we skip saving when the epoch ends. - To disable, set ``every_n_epochs = 0``. This value must be ``None`` or non-negative. + This value must be ``None`` or non-negative. + To disable saving after each epoch, set ``every_n_epochs = 0``. + If all of ``every_n_epochs``, ``every_n_train_steps`` and + ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch + (equivalent to ``every_n_epochs = 1``). + If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``, + saving at the end of each epoch is disabled + (equivalent to ``every_n_epochs = 0``). This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``. Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` From a18b6409d13ded0e12752a9649bdb8dc763d4cbb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 19 Nov 2021 09:34:23 -0800 Subject: [PATCH 07/59] Check torch.distributed availability before sharded tensor state dict hook registration (#10621) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- CHANGELOG.md | 3 +++ pytorch_lightning/core/lightning.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f5f3644b6846..96830878d84ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- When a tensor is logged with `self.log`, run its computation with the same `dtype` ([#10076](https://github.com/PyTorchLightning/pytorch-lightning/pull/10076)) +- Fixed `ShardedTensor` state dict hook registration to check if torch distributed is available ([#10621](https://github.com/PyTorchLightning/pytorch-lightning/pull/10621)) + + - Fixed LigtningLite `_wrap_init` popping unexisting keys from DataLoader signature parameters ([#10613](https://github.com/PyTorchLightning/pytorch-lightning/pull/10613)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index dc3ce5f0f4063..89f46949a525c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -46,7 +46,7 @@ ) from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp +from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_debug, sync_ddp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import get_model_size_mb from pytorch_lightning.utilities.model_summary import ModelSummary, summarize @@ -1990,7 +1990,8 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None: These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly. """ - if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS: + if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS or not torch.distributed.is_available(): + rank_zero_debug("Could not register sharded tensor state dict hooks") return from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook From af0bb96f0ff645102680a7adc99dc131cfeb9c0b Mon Sep 17 00:00:00 2001 From: puhuk Date: Sat, 20 Nov 2021 02:37:39 +0900 Subject: [PATCH 08/59] Remove the "_precision" suffix from some precision plugin files (#10052) --- pytorch_lightning/plugins/__init__.py | 4 ++-- pytorch_lightning/plugins/precision/__init__.py | 4 ++-- .../precision/{deepspeed_precision.py => deepspeed.py} | 0 .../plugins/precision/{ipu_precision.py => ipu.py} | 0 tests/plugins/test_deepspeed_plugin.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) rename pytorch_lightning/plugins/precision/{deepspeed_precision.py => deepspeed.py} (100%) rename pytorch_lightning/plugins/precision/{ipu_precision.py => ipu.py} (100%) diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 0194591bfc06c..5ccbe8957694b 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -10,10 +10,10 @@ TrainingTypePluginsRegistry, ) from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin -from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin +from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin +from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin 
import PrecisionPlugin from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index d055144a3f7e4..b407e47ca9337 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,10 +1,10 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 -from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 FullyShardedNativeMixedPrecisionPlugin, ) -from pytorch_lightning.plugins.precision.ipu_precision import IPUPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed.py similarity index 100% rename from pytorch_lightning/plugins/precision/deepspeed_precision.py rename to pytorch_lightning/plugins/precision/deepspeed.py diff --git a/pytorch_lightning/plugins/precision/ipu_precision.py b/pytorch_lightning/plugins/precision/ipu.py similarity index 100% rename from pytorch_lightning/plugins/precision/ipu_precision.py rename to pytorch_lightning/plugins/precision/ipu.py diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 480b050c39b36..397803e1d8a17 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -213,7 +213,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args trainer = Trainer( fast_dev_run=True, default_root_dir=tmpdir, strategy=DeepSpeedPlugin(), gpus=1, precision=16, track_grad_norm=2 ) - from pytorch_lightning.plugins.precision.deepspeed_precision import warning_cache + from pytorch_lightning.plugins.precision.deepspeed import warning_cache with pytest.warns(UserWarning, match="will be ignored since DeepSpeed handles the backward"): trainer.fit(model) From 8ea39d2c8f68cc33273c3431a310a262e2240cf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 21 Nov 2021 02:33:13 +0100 Subject: [PATCH 09/59] LiteDataLoader code improvements and docs (#10625) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/lite/lite.py | 10 ++++++---- pytorch_lightning/lite/wrappers.py | 21 +++++++++------------ 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index ca88095dfc673..f5fdd0221cbe3 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -16,7 +16,7 @@ from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union 
+from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -187,7 +187,7 @@ def setup( def setup_dataloaders( self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True - ) -> Union[Iterable, List[Iterable]]: + ) -> Union[DataLoader, List[DataLoader]]: """Setup one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one. @@ -212,7 +212,7 @@ def setup_dataloaders( def _setup_dataloader( self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True - ) -> Iterable: + ) -> DataLoader: """Setup a single dataloader for accelerated training. Args: @@ -245,7 +245,9 @@ def _setup_dataloader( dataloader = self._strategy.process_dataloader(dataloader) device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None - return _LiteDataLoader(dataloader=dataloader, device=device) + lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device) + lite_dataloader = cast(DataLoader, lite_dataloader) + return lite_dataloader def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None: """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you. diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 6b8e44b610352..3cd2f5eb69712 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -15,7 +15,7 @@ import inspect from contextlib import contextmanager from itertools import chain -from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Sized, Type, Union +from typing import Any, Callable, Dict, Generator, Iterator, Optional, Set, Type, Union import torch from torch import nn as nn @@ -157,29 +157,26 @@ def _replace_dataloader_init_method() -> Generator: class _LiteDataLoader: - def __init__(self, dataloader: Union[Iterable, DataLoader], device: Optional[torch.device] = None) -> None: - """The LiteDataLoader is an extension of an Iterator. It would move the data to the device automatically if - the device is specified. + def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None: + """The LiteDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the + device automatically if the device is specified. Args: - dataloader: The current dataloader to be used. + dataloader: The dataloader to wrap device: The device to which the data should be moved. By default the device is `None` and no data transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`). 
""" - super().__init__() - self.__dict__.update(getattr(dataloader, "__dict__", {})) + self.__dict__.update(dataloader.__dict__) self._dataloader = dataloader self._device = device - def __len__(self) -> Union[int, float]: - if isinstance(self._dataloader, Sized): - return len(self._dataloader) - return float("inf") - @property def device(self) -> Optional[torch.device]: return self._device + def __len__(self) -> int: + return len(self._dataloader) + def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: iterator = iter(self._dataloader) if self._device is None: From ce0a977742872736f150f6d37ecaa301a318668f Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 22 Nov 2021 13:36:35 +0530 Subject: [PATCH 10/59] Moved `env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` (#10501) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- CHANGELOG.md | 3 ++ .../trainer/connectors/env_vars_connector.py | 40 ------------------- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/utilities/argparse.py | 20 ++++++++++ 4 files changed, 24 insertions(+), 41 deletions(-) delete mode 100644 pytorch_lightning/trainer/connectors/env_vars_connector.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 96830878d84ae..bba1cd319e706 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Moved `trainer.connectors.env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` ([#10501](https://github.com/PyTorchLightning/pytorch-lightning/pull/10501)) + + - Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426)) diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py deleted file mode 100644 index 4d130ca8e720b..0000000000000 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from functools import wraps -from typing import Callable - -from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables - - -def _defaults_from_env_vars(fn: Callable) -> Callable: - """Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should - be moved automatically to the correct device.""" - - @wraps(fn) - def insert_env_defaults(self, *args, **kwargs): - cls = self.__class__ # get the class - if args: # inace any args passed move them to kwargs - # parse only the argument names - cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] - # convert args to kwargs - kwargs.update(dict(zip(cls_arg_names, args))) - env_variables = vars(parse_env_variables(cls)) - # update the kwargs by env variables - kwargs = dict(list(env_variables.items()) + list(kwargs.items())) - - # all args were already moved to kwargs - return fn(self, **kwargs) - - return insert_env_defaults diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2f6e987635d47..38eb44bced223 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -53,7 +53,6 @@ from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector -from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector @@ -75,6 +74,7 @@ rank_zero_warn, ) from pytorch_lightning.utilities.argparse import ( + _defaults_from_env_vars, add_argparse_args, from_argparse_args, parse_argparser, diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 61443bea07cd7..ad707b036047a 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -16,6 +16,7 @@ from abc import ABC from argparse import _ArgumentGroup, ArgumentParser, Namespace from contextlib import suppress +from functools import wraps from typing import Any, Callable, Dict, List, Tuple, Type, Union import pytorch_lightning as pl @@ -312,3 +313,22 @@ def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]: return int(x) except ValueError: return x + + +def _defaults_from_env_vars(fn: Callable) -> Callable: + @wraps(fn) + def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any: + cls = self.__class__ # get the class + if args: # in case any args passed move them to kwargs + # parse only the argument names + cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] + # convert args to kwargs + kwargs.update(dict(zip(cls_arg_names, args))) + env_variables = vars(parse_env_variables(cls)) + # update the kwargs by env variables + kwargs = dict(list(env_variables.items()) + list(kwargs.items())) + + # all args were already moved to kwargs + return fn(self, **kwargs) + + return insert_env_defaults From eb13e1df89b70af0e90e3f1e9c4f94bff6728aee Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 22 Nov 2021 16:49:46 +0530 Subject: [PATCH 11/59] update bug_report model links and notebook (#10665) --- .github/ISSUE_TEMPLATE/bug_report.md | 10 +- 
.../advanced/fault_tolerant_training.rst | 2 +- pl_examples/bug_report/The_BoringModel.ipynb | 1420 ----------------- pl_examples/bug_report/bug_report_model.ipynb | 267 ++++ pl_examples/bug_report/bug_report_model.py | 1 + 5 files changed, 274 insertions(+), 1426 deletions(-) delete mode 100644 pl_examples/bug_report/The_BoringModel.ipynb create mode 100644 pl_examples/bug_report/bug_report_model.ipynb diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index b7f3574f62f99..546f5bc2ef8fa 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -16,11 +16,11 @@ assignees: '' Please reproduce using the BoringModel! You can use the following Colab link: -https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report/The_BoringModel.ipynb +https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report/bug_report_model.ipynb IMPORTANT: has to be public. or this simple template: -https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report_model.py +https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report/bug_report_model.py If you could not reproduce using the BoringModel and still think there's a bug, please post here but remember, bugs with code are fixed faster! @@ -46,9 +46,9 @@ python collect_env_details.py You can also fill out the list below manually. --> -- PyTorch Lightning Version (e.g., 1.3.0): -- PyTorch Version (e.g., 1.8) -- Python version: +- PyTorch Lightning Version (e.g., 1.5.0): +- PyTorch Version (e.g., 1.10): +- Python version (e.g., 3.9): - OS (e.g., Linux): - CUDA/cuDNN version: - GPU models and configuration: diff --git a/docs/source/advanced/fault_tolerant_training.rst b/docs/source/advanced/fault_tolerant_training.rst index e4a61b27e294d..63a3ce41ee8b3 100644 --- a/docs/source/advanced/fault_tolerant_training.rst +++ b/docs/source/advanced/fault_tolerant_training.rst @@ -134,7 +134,7 @@ Performance Impacts ------------------- Fault-tolerant Training was tested on common and worst-case scenarios in order to measure the impact of the internal state tracking on the total training time. -On tiny models like the `BoringModel and RandomDataset `_ +On tiny models like the `BoringModel and RandomDataset `_ which has virtually no data loading and processing overhead, we noticed up to 50% longer training time with fault tolerance enabled. In this worst-case scenario, fault-tolerant adds an overhead that is noticeable in comparison to the compute time for dataloading itself. However, for more realistic training workloads where data loading and preprocessing is more expensive, the constant overhead that fault tolerance adds becomes less noticeable or not noticeable at all. 
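For reference, the `_defaults_from_env_vars` decorator relocated to `utilities/argparse.py` above keeps its previous behavior: positional arguments are folded into kwargs, then values parsed by `parse_env_variables` fill in defaults that were not passed explicitly. A rough sketch, assuming the `PL_TRAINER_<ARGUMENT>` environment-variable naming and that values are parsed into their Python types:

import os
from pytorch_lightning import Trainer

# Environment variables can provide defaults for Trainer arguments.
os.environ["PL_TRAINER_MAX_EPOCHS"] = "3"
trainer = Trainer()
assert trainer.max_epochs == 3

# Explicit keyword arguments still take precedence, because the decorator
# merges the passed kwargs on top of the parsed environment values.
trainer = Trainer(max_epochs=5)
assert trainer.max_epochs == 5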
diff --git a/pl_examples/bug_report/The_BoringModel.ipynb b/pl_examples/bug_report/The_BoringModel.ipynb deleted file mode 100644 index 9b061c4283cbf..0000000000000 --- a/pl_examples/bug_report/The_BoringModel.ipynb +++ /dev/null @@ -1,1420 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "The BoringModel.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU", - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "d79c1628eded487a974da18a2ea1f98b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - "_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_02695b143b764932ba8d0c08a872987e", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_28eb6a3218f64f26abcdff756ffda3ad", - "IPY_MODEL_02cfffd590014c3cbc44ab06c69f9181", - "IPY_MODEL_0d7c50e36cb84f01a57a9d7d8b913393" - ] - } - }, - "02695b143b764932ba8d0c08a872987e": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": "row wrap", - "width": "100%", - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": "inline-flex", - "left": null - } - }, - "28eb6a3218f64f26abcdff756ffda3ad": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_6ba2782883ae424dbfc8868224d95da9", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": "Epoch 0: 100%", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_baa4aacd0da64cf291fb31c000724573" - } - }, - "02cfffd590014c3cbc44ab06c69f9181": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_7dad3d2feced492a999fb6c91186be50", - "_dom_classes": [], - "description": "", - "_model_name": "FloatProgressModel", - "bar_style": "success", - "max": 2, - 
"_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 2, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_ea702a091eb642f7bdda81aa55db8c26" - } - }, - "0d7c50e36cb84f01a57a9d7d8b913393": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_4802a47c6dfb439c83d8b860dce42006", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 2/2 [00:00<00:00, 9.45it/s, loss=-0.618, v_num=0]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_68c87e6a7fcf4e4eab98a941c7c3e867" - } - }, - "6ba2782883ae424dbfc8868224d95da9": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "baa4aacd0da64cf291fb31c000724573": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "7dad3d2feced492a999fb6c91186be50": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } - }, - "ea702a091eb642f7bdda81aa55db8c26": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": 
null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": "2", - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "4802a47c6dfb439c83d8b860dce42006": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "68c87e6a7fcf4e4eab98a941c7c3e867": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "e6cbe583c2e14986b4faeb27e31f73e1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - "_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_672dd78899f944cea7e57f388f3ecb31", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_cd61dda59d104e0a8a8aa9bfc1e55c24", - "IPY_MODEL_1cd72d82332941a6929f88fad5173096", - "IPY_MODEL_92a38638060c4ed5b6d44a2078667e53" - ] - } - }, - "672dd78899f944cea7e57f388f3ecb31": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - 
"justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": "row wrap", - "width": "100%", - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": "inline-flex", - "left": null - } - }, - "cd61dda59d104e0a8a8aa9bfc1e55c24": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_bdc9b06391ee47478efd58cc91ca87ac", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": "Validating: 0%", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_ee80657d62c6452d9e9ac199157cdf2a" - } - }, - "1cd72d82332941a6929f88fad5173096": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_eb16b87bcb8d4ca6a83e8b44ea2d1311", - "_dom_classes": [], - "description": "", - "_model_name": "FloatProgressModel", - "bar_style": "", - "max": 1, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 1, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_2a6327dd568241e3acbb6aec1926bd80" - } - }, - "92a38638060c4ed5b6d44a2078667e53": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_a45aba8517e14654850453159780b54a", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 0/1 [00:00<?, ?it/s]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_7fb167222e7143b789b7f40af7cb39dd" - } - }, - "bdc9b06391ee47478efd58cc91ca87ac": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - 
"ee80657d62c6452d9e9ac199157cdf2a": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "eb16b87bcb8d4ca6a83e8b44ea2d1311": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } - }, - "2a6327dd568241e3acbb6aec1926bd80": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": "2", - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "a45aba8517e14654850453159780b54a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "7fb167222e7143b789b7f40af7cb39dd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - 
"model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "abe1c0c4dac94e0e9b894bb69c3ec450": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - "_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_23763e19d40d4020b3342a47366e2e19", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_0b7b7da6a6134f0fb26a05adc062ee6f", - "IPY_MODEL_9941635d9d694ba7bce0c7a14c500e5e", - "IPY_MODEL_c7f1407ba92f4dc6ba34bd9cf73fea69" - ] - } - }, - "23763e19d40d4020b3342a47366e2e19": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": "row wrap", - "width": "100%", - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": "inline-flex", - "left": null - } - }, - "0b7b7da6a6134f0fb26a05adc062ee6f": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_86f2e0a558cc419e84ed9192ccd3d1b6", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": "Testing: 100%", - "_view_count": null, - 
"_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_141a9c35ade14d9e8645b2c108ab4d66" - } - }, - "9941635d9d694ba7bce0c7a14c500e5e": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_833bb79bb1214a3a88795f41b9375690", - "_dom_classes": [], - "description": "", - "_model_name": "FloatProgressModel", - "bar_style": "success", - "max": 1, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 1, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_51b32955ad544803b1d78f07bc685569" - } - }, - "c7f1407ba92f4dc6ba34bd9cf73fea69": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_a6b2764a5fa9444a9d77e8d74c67ef47", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 32/32 [00:00<00:00, 174.23it/s]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_3f0c08c03e284ebb905dae8aca72fffc" - } - }, - "86f2e0a558cc419e84ed9192ccd3d1b6": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "141a9c35ade14d9e8645b2c108ab4d66": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "833bb79bb1214a3a88795f41b9375690": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": 
"", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } - }, - "51b32955ad544803b1d78f07bc685569": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": "2", - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "a6b2764a5fa9444a9d77e8d74c67ef47": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "3f0c08c03e284ebb905dae8aca72fffc": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - } - } - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "rR4_BAUYs3Mb" - }, - "source": [ - 
"![image.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAQYAAABSCAYAAAC2XXppAAAaWElEQVR4Ae2dh3dU1RbG3x/i0mXvoighgKgI2BtW7KKiSAcFFVBQmgqKil2UJzYUQXl2ERSRpoiIIJhMOiEQkpCQRup+63eSHS73zkxmkmFC2Wetm5m55ZTv7P3tcs5M/iNWDAFDwBDwIfAf32f7aAgYAoaAGDGYEBgChkAAASOGACR2whAwBIwYTAYMAUMggIARQwASO2EIGAJGDCYDhoAhEEDAiCEAiZ0wBAwBIwaTAUPAEAggYMQQgMROGAKGgBGDyYAhYAgEEDBiCEBiJwwBQ8CIwWTAEDAEAggYMQQgsROGgCFgxGAyYAgYAgEEjBgCkNgJQ8AQMGIwGTAEDIEAAkYMAUjshCFgCBgxmAwYAoZAAAEjhgAkdsIQMASMGEwGDAFDIIBAhxJDXW2j7NldL6WFdVJb0xjonJ0wBAyBjkGgQ4ihoUGkvLhONi6vlLlPFMg7j++Qv5aVS0VZvdTXG0F0jChYq4bAPgSSSwyNIpV7GiTzrypZMKNIJvXLlhHdQzKye0ie6pcjC57bKRkbqqW6vF4ajR/2zZK9MwSSjEDSiGFvVaMUZNbI93NL5Lk7cmRUz5AMS0mX4Rxdm47RF2TIs3fmynfv7pYdmTWyt7ohyXBYc4aAIQACB5wYyCMUFdTJmsVlMntIvoy9OFOGdQ3JsNR0GdEttO9IDbUQxNjemfLqsHxZ9UWZ7C6slTrLP5i0GgJJReCAEQOhQNmuOvn7l0r574QCGX95lgyHELqmywg/KXgJovme4d1CMv6KbJk7YYdsWlHhkpQN9UnFxhozBI5YBBJPDJpH2FgtC2cVydM3ZMuIbhkyrEu6DI9GCF5y6NbkPUAio3pkyuQbc2TRrELJ3FQtVXuMHY5YabWBJw2BhBJDTVWjyw38MK9EZt6dJ6MvyGzyEuIghHDhBbmIhy/MlJn35MmS9y3/kDTpsIaOWAQSQgyN9SK7ttXK2q/L5PWR22Vsb2/Y4Mkj+LyC/Ugg2rXUphBkeGpIHuuT5dqgLfIP9bW2fHHESq8N/IAh0C5iII+wp7hetv5WKe9P2ikTroIQMppWG+L1ElD+1FZIpDn/MKpbSJ68Ols+eHqnbF5VKRWl9cLeCCuGgCGQGATaRQx7iurkqzdKZHr/HBnZo42EgKfQTApjLsqQUT0z9q1UhPMimlcvyD+M7pkh027JlR/nlUhZUV1iELFaDAFDoO3LlY0NIqEN1fJwrwjLj+GUOtw5FyaE5ImrsmXuuO0yplemEDK0GmZAECnpbpVjWv8c2fp7pU2nIWAIJAiBNnsMuO5b11bK0C4sP8agyFFIYcLVWfLd3BJZvqC0dULw15MSkonX5sjmXysSBIlVYwgYAu0jht8qY7PufmXmc2pIhqaky8TrcmTVwlIp3VUvC2YUykPn+DY+hXvWey61aTv15pWxEUNjY6Ps2bNH/v77b9mwYUPg4HxmZqaUlpbGLR01NTWSnZ0dqDNcO3ruzz//lFAoJBUVsfU/7k7F8ABj3bx5s6xfv162b98udXVtD8tyc3Pl999/d0dxcbGAd7RC27T722+/SX5+vjS0I1mUk5Pj6vnjjz86FE/w++uvv1xftm7dGm34B+21jiEGwocu6TKpX46sXFQqe6sbHTHMGrhNhpx3YImhvr5e1q1bJ7fddptce+217rj66qvlqquucsf1118vDz30kLz55pvyzz//SG1tbcyTh1LNmDFDrrzySqFOPWjnuuuuc8c111zTcp7rl19+uYwdO1Y2bdoUczuJvnHt2rUycOBAueKKK+T9999vEylqn+bMmeNwvOyyy2TFihWt4sdcgHmvXr1c21VVVVpV3K+vv/660C54o5CtkZK3AQhq586dcT3jfd77HsODfF188cXyyCOPeC8dMu+TTwyQQkq6TLo+R1YvLnOkQL4i79+9Mu7SrLg2Qbk8RJweA2yOwJ5//vnSqVMnOeecc+Tcc8+V8847z71y7owzzpAuXbrI4MGDBaWJ1Yrl5eXJ448/LmeffbZ07ty55aANDs7z6r125plnyr333uusZkdJzc8//+wU6rTTTpNXXnlFsPRtLS+88IIb31lnnSVLliwRvKhoZdWqVXLBBRfIKaecIm+88YZUVrY9VzR9+nShXeYSzy/WecO6z549W1577bVW+xttLHqtrKxMLrnkEjn11FPd3Or5Q+k1ucTQHD48dX2OrFpcJjXVTW4m34X44/ty912JuPMVbSQGhBEBuvHGG2XmzJnueO655+SJJ56Q/v37O+GGMEaMGCHbtm2LaU53794tixcvFgT02WeflWeeecYdDzzwgHTr1s21d8cdd8iUKVNark2dOlU++OADwQXvqAIx4OWcfvrp7SaGWbNmtRDsjz/+2KqiEUY9//zz8tRTT8ny5ctbvT8aRmAOsaekpDgPLFZieOyxx5wxGDRokLTHY9G+VVdXy0svveTGNG/ePD19SL0mlRj4evXUm3Nk9f/2kQJoVZc3yOezimTQOWnxJx/bQAy//vqrs1IIEESA61deXu4OrCUexX333eesT58+feSzzz6LycXEdUUoqM97fPzxx9K3b1/p2rWrU7yioqL9riOMhDhaeI9nE6tg63PRXqmPI1yBGAilIAasZklJiRsv98fjjlM3xADh4h3FQgy0AVZY2WjeBdf27t0brvst5yAG2gXnLVu2uL7H8hyeId4Sr4kgBjqEPDGm9nhALQPzvEE2wCHeefFUEdPbpBLDY30yZfmnpS2egvawuqJBFs8ulqFdQjKUL1l5k4utvW8HMSBAWCqvUtInBHXhwoXO8mDpp02b5rwGLBqkkZ6erl1veWWy0tLSBCUj+VZYWNhy7fPPP3euJUREHBxO+OgDzxBzf/PNN7Jo0SL57rvvXCITRfUWhG3jxo3y/fffu7YyMjJk9erV8sMPP7i8COREQSlIpNIn+oA3Q/9JkHpzJ0oMhFAvv/yya3PZsmWuD/QFtzxWAY+XGCBJcKXv9NU/F+RtCDfAY/78+Y5sGC8YQzwrV65sGYsSQ2pqqrsOlswjzxHW8JwqFN7dTz/9JEuXLnX5AEK6W265Rb799lt3LwlRf19Q9F9++cU9w3NgDpb+ArbgzJhIhIYryBDXqYd+EM6ARaSCDDCeL774Qj788EP5+uuvnRySNKYeMKR/iSpJIwb2JjzeN1N++rg08EUovjW5La1G5jxa0LRaEQ85tIMYUNRwxIBAIHA9e/Z01oe8AZ9vvvlmF4vzDOThLZDChAkTXMJpzJgxLsuv11FKPAbawyL7lQxSYXWCxOUNN9wg3bt3dzkO2r/99tudl0EyTQWV0IYQ5MILL5Q777zThTuEAuRFCGNIotE/lHrUqFFy6aWXurZpn/f0DzLQfigxEJ+TeCX0IXHG/YRcd911l3z66aeCMrVW4iUGViNIzDLW9957r6VPtAP5TZ48WUjYMjbyM/Rn5MiRDhf6yPi1X15iwBPs16+fC2s
ICbl32LBhTrnwxFBEzukck/dhvOSeIBZCSCVYHXNWVpYLuS666CLhIHkMvii21xsDe2SFurkerpBPoR3q6d27tyMl5h9y9BdWW8jdgBN9BAeeffDBB4XwR/uCDCaqJI0Y1AsgwfjJs4Wya9v+SSnIobigVr55u9h912JISoyeQzuIQT0GP5iEEySjmAAEhfckFp988knnbpO9x1prQbHJEzBBPXr0cCsauJJaohEDVh3PAAFHgJl4svQDBgxwgodbzjkEAHJCqBEUyIp4GoHGq8Ha3X333U65yFeQRIQEcK3J+KPctIGwQgAQ1SeffOIIRImB8dIe93MvpMR4qIOxQQ7ecen4vK/xEoMSsD/5SJ9QLjCBFMjygwHEyTnGzfhRTi8xMAa9ritM1MN5xsHzeA4cKP/999/viJhrEAVtQIyMg7nBw0AeduzY4aw/JEV4CSkTeik2eCVasNxgT/KRkFQLcgJpUxc5CGRLQ0zqot+Mk7yLFpSdvBTXmDf6P2TIEDc3zBXj4mCeCJ8SVZJLDGyESg253Y0fTimUwtwa2W+Vm99wKGmQpfNKZGwfdkDGQA7tIAYUbvjw4W69GdcU6wX7M2lYYCYCq8PKBJOKO4rFglAmTZrUEhIwkSxLcT/Kyf3qsjJR0YgBT2D06NFOyBFMEpbkQLBouPN4IUw6gvH000+7kAbF19UPlMb7DHsB6CdLdrjHCNJHH33kVj1wa7HKjIkxEFOjILih5BiUMFhy5F4OFASLRgz+8MMPt5okjZcYCBNQMurHirKfA8XCuiPw9JM616xZ48Ik7h8/frw7r8SgoRYeA89wnpUeyIUlZ+aDsUImKCAEh1X/999/XciBEoMVc8f+EtxzyJc5JCxAHrD84I9BIIT56quvZNy4ca5OcIOEdDVHiYG2IB4t1Es9ENI777zj6iIMIDQgCU49hHO0R6F9wljOMSa8IMIXCICwAg9DPQgI/9AlBvIFXUNuU9TYXpmy6MWiwLcjWaHYsKxCxvaNfWs0vxcZ6wYnXD5NPiIoKOOtt97qmJpXlAaQmVSsJEqim49w8xBYJgoLjavL5H355ZdOsWBwVjhUQFQgIhED1p+4Hw+EvuCRoKgII2EDFguCQJAQmptuuskpsRID51Bo7uEZxsZzKAheBEqFJ0PsqtdYr2f8tIuQ01f1GBgzeyoYp95Pf/BeEEysWWvuanuJgfAGoYeMGB/EzXgVE17pH33BWvs9Bs7RVzDXpC7P4JXhfmPFmSOwBSu8DTBkTrHEeESc18QvBoG2uM5qFffzLOcJLSATSAXiRh4okYgBo8OY6AOrFdTDQT8hCPpHXXh3FOaGfRnMC3IJWSkOzA8hJaEfdR76xAA5NHsO744rEPYwaKmvFflnTaW8MDBPRup9Byj5iJXCujARJ598sjtwZ5k0FApXGqXC7dMCQeB+IwQ8/9Zbb7mYmLgexSYGRNlUqPS5SMRADEsdCATuKZYMofQWhIZ7UHSUhWQaioK1QlgRTLwELQg2ysTYcHuVvPQ6RKaCjXDz2UsMJB+9xEb7WDnGh1VrbSdfIogBvBgv48PDoQ/eAkbkeSIRA1aUcXvnAWuN646SYYUZOwXSxDtUYvC35SUGjIE3lKJ+rDveDsqJN0OJRgzImxKDu7n5D+0yX/QPnCmQMB4h55hv73i4zjzieXD9sCGGUT1C8u2cfRtpIIUtayrl+QF5TV+/jvX7F+0IJQCd+BNhfvHFF91BcpDlSbLSWFfvZKBE7FAk9kYoiUe5j89YKSYPVveXSMSAMNAegow7S8bZX7AMKAfhBB4MxOQlBnYsFhQUtDxGn3GbIQbCiFjcSy8xvPrqq265UivEghPqJJsYIEsEHsvqV1b6RijHGMN5DOAJEXjnjs/qCULkSgy7du1qlRiYXxQaj8FLDPTj3XffbSEG8iWUSMQAzhAIROLf30BfIR6uQQzIGiSMEVJicJV7/uA90P5hRQyP9s2UTSuadrk5UlhbKc/fm+d+KDamb1eqJ9EOYkCAiFexGnqgWFhyFJLJ8RfiWZQZYiBxxOYYBA5rzn4FWNxfIhEDbRBf4n2gCFhr2vcWvAGsI0qAZWOlQUMJrBwxrJcYEHjyD5Ae+RDyFP4+kUhFUSA5vCDNMSBgkYgBAY3XY6BehDda8ecYICLyPZAgSgTJeb026mKJECWNRAx4fIwtFmJgiRhcUXxCCT8JgWc0Ypg7d25CiAEviDEpMfAZWcOTpG8QIPPmLXh2zP/hE0p0C8nkG3OlKJ9fYBLZAinc1wZSaA432ppjgBhQunAE4J0A73smjCQlwoQVhRQgCWI91vvDlUjEwL0kk1TwsPCQC8IKaSAIfGeDHATCMXToUOcBRCMG6iRDzsoCAsMzKBpkB0EQv0McxM0ks/Bw/B6DN5RQjyFeYkBp8W5YrycW9x7kCGgDxQ1HDFhcPCG8MJZuSa6SCIUQwIv+E2pECiXiIQaIGKWjLhK2uO+QmZJKa8SQKI8BufJ6DNoHPCNCXWSNZWbGz/yTJCUkwqDQ90M/lEht+gczr43Ml9rqxqbwoa2k0A5iUCsNMTAp8RQ23ZChRllRANx8Mup+N1PrVGJgEv37GEhm8Sx1oAgIKQLA+j1JR5JPeAaEGuQgsGgaStC+32OgTZJSrPXTN8jvnnvucQTIWEls4U0Q5xIf4ykpMVAfHoOfGFiNgBhIfsaSY6B+BJn7Wfpjvd170B/Ii7FADHgHeCvgABFB1Hg6+n0DyAHypC6Uh7rBSvHyL1eCcziPAQ8KLL2hBAQMVlhq+g2RkgSGTJELiIHlQkg2UihB3+mLP5RQj07lAJy5j/v9oYQSA9fwzNTTAm8I68QTT3RLyYyfEJY+MbfMmRqoWMJG7Utrr8ldrmxW5FHnZ8iXrxVL2rpKmXFvbvzhg4YRbSQGJggwmQRCgXiJAcuL4BJKIDCsZrDTLZLngUJj4WiPfIaucjA5PIPVJsdBOEICFCGlXl5RYLLmuKysgVOwnCQEjznmGGf5ISpvYTzkPkhCMk4sDgKEoCJgWBiUgXCCe9l9h4dBfXxvAbLQQl9x50866STnubD8F61g3ekz99NWuOOoo45yxMeSIeEGpHP00Ue7UEqxgTRY7UEBUCatj7FAlqwm+UMJyPr44493OLNSo1af/pKMBAvamThxYkuOgWtYYbw1sKf+4447zi3xopwQA0uRxx57rCNtP/lDZtTJdcZCwQuB6DmPZ6YFnLmP8yQNvYV5wChwDQ9RiYHzEI56UOCpWECoyMZB5zFsWV0pAzulybCUkPtXc7qJKeor+xguzpT5U3fKzAF5MjSlafky6jNeIvC+d1/fDsm4K7Jk44p9G4q8gPvfAzTMinCQLFywYMF+AuS/P9JnFASBQThZHvQqk/8ZQg88APIZrFogbN4COZAEQ3BQLCwjqyIoJISBVVWryHNYdPpN/xEwf16CexgnBEJowhIkqxdYap6BqHDnVfgYC+3iGbCs5xV++op1e/TRR93uO+8KiHcM+p7x4f6CL5Y33EEoQK6E8IZQA5eYcyznQbpKsFhzQhCSsuz8o4
+47pAeG4dQCJRJsWFvAeERnhFek9ZD3+i3tsN91K0FrAgD3377bTdH4MC+Cc6DEfsFmAu+ZEefvYVlRiw4+QklTe7hGZ0fvR+y4j7uZ669BRJjfFzDo6FtCmPgGuQPtsgD44CQIAzagPAJJVrz5rzttfa+zR4DO5N2ba9r2cb84NlpMf9oCx7DuMuyZERqRszPBIgjNSSDzm760tX8aeykjJ7k8gLBxAE0wqJC5b3e2nueR7jwAlgCY9+61zr5n8c91vawkl6B1Xs5hxBCECgtm294JQHlFWLu5zNkQFJOY3Wtx/tKn2ibXAX1ET8zZr9wY52pi/uI771joV+0gaJxj5KJtx3ve57HsyEhyv2RDu6jbohHdwNCSHqO3AoKjgcDdlxjzNxPP3G38YIIMdTLoE5tz99PPmu/whEpY+Z5+s1YwY1Cf2gfggI7PnsLdTFPkDA46jPkiajLG5Zxnfu4P1wfaIdrtKOFkItNbhgWFJ86kFnmkHETGuJdkqTk+USVthODiLCNuTCvTlYuLpMZd+c5RR0U4y8wuZWHWJckvV5Ct5A81DldBndOk5kDcmXlojIpzmdjT3yQMMF6xPdkUwyPBYWpEVCEprWibfkFy/+c975o9+p9/ufDfdZ79TXaPfFei3a/thfu1fucXtdzKCnWE9efRCJYYy2xuIRw7DIlj4DbT2jmXXXx16V18hrtmt4X7h49x2u4otf918LdH+lentVr3uf48hghJZvnCCcwSIRFhD+ETuQZkEM8Gj8Z+vsTz+d2EYNrqFGkrlokd3O1fPlmsfsJ+QfOTJMh/BZk9zb+FqSPCNRb4NuXD3ZKk4nXZss3b5VIfjr/+BYFj2fI7bsX8PmGHclLYmNc5UROSPt6d/g8TdadDT+aEyGfw+4/fvEKwjjhhBPcZ7ajexXp8EGgaSR4HnhF5CbIUeEZgAO5BeQPfMi3QBSJLO0nhube8FN9lbsbJW1dlXwweaeM7pkpEITzDNpJEOQwCFXGXJQpH0/dKaE/qtz/kvDumkwkKNHqwtVnSy0JMSaH7yVYSTwChAysDBBDk5RFKUi64SUQwrFio3mAxLd+8NQI6RFCkHdglUYTyOCAt4AXQb4CvBJZEkYM2qmGOpGSgjr5c+keeXXYdhlybpNSx/3LTHgNzXkEfjT29VH58ufSctldyA+YaGvJfyW2YymMRBmCS1xq5cAgQAKOOJ1kI8lKkqZ8x0PX8o8UTw1ywHNgFYnkMDgQYvCdF3IWBwKHhBODikhdjUhBVq38/MlumdY/VwZ1SnO5gVjDi8HnprtfjJ5+e46sWFgqu/JqpY78YhLDBh2L/xWB1cN/zT4nFgGUgpwDFhFS5uDz4Rw+REKQcZN4VhyQwQOFwwEjBje4BpGqikbJ3bJX/vdKkVuJILzAA4hEEC5s6JQm4y/Pcv/lyuURKpKbR4g0MXbeEDhSEDiwxNCMIq5/eUm9/LuuSuY9uUNG9AjJA532X94cTh6hU5qMPj/D/U/K0PrmPMJB4CEcKcJg4zQEFIGkEIM2Rv6heHudrF+yR14elC+DO6e7JU7+yQz/T2L24G2yflm5lHZwHkH7a6+GwJGKQFKJQUGu3SuyI6dWfpq/W6bcnC1Tb8lxPxK7a3ut8EMtB0MeQftqr4bAkYhAhxCDA7pBpLqyUXZk18r2jBrZW2GEcCQKoI354ESg44ihGQ82JyVzg9LBOQ3WK0Pg4EKgw4nh4ILDemMIGAIgYMRgcmAIGAIBBIwYApDYCUPAEDBiMBkwBAyBAAJGDAFI7IQhYAgYMZgMGAKGQAABI4YAJHbCEDAEjBhMBgwBQyCAgBFDABI7YQgYAkYMJgOGgCEQQMCIIQCJnTAEDAEjBpMBQ8AQCCBgxBCAxE4YAoaAEYPJgCFgCAQQMGIIQGInDAFDwIjBZMAQMAQCCBgxBCCxE4aAIWDEYDJgCBgCAQSMGAKQ2AlDwBD4P9CuROTFaWXrAAAAAElFTkSuQmCC)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "i7XbLCXGkll9" - }, - "source": [ - "# The Boring Model\n", - "Replicate a bug you experience, using this model.\n", - "\n", - "[Remember! we're always available for support on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2LODD6w9ixlT" - }, - "source": [ - "---\n", - "## Setup env" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zK7-Gg69kMnG" - }, - "source": [ - "%%capture\n", - "! pip install pytorch-lightning --upgrade" - ], - "execution_count": 1, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WvuSN5jEbY8P" - }, - "source": [ - "---\n", - "## Deps" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "w4_TYnt_keJi" - }, - "source": [ - "import os\n", - "\n", - "import torch\n", - "from torch.utils.data import DataLoader, Dataset\n", - "\n", - "from pytorch_lightning import LightningModule, Trainer" - ], - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrJDukwPtUnS" - }, - "source": [ - "---\n", - "## Data\n", - "Random data is best for debugging. 
If you need special tensor shapes, batch compositions, or dataloaders, modify as needed" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "hvgTiaZpkvwS" - }, - "source": [ - "class RandomDataset(Dataset):\n", - " def __init__(self, size, num_samples):\n", - " self.len = num_samples\n", - " self.data = torch.randn(num_samples, size)\n", - "\n", - " def __getitem__(self, index):\n", - " return self.data[index]\n", - "\n", - " def __len__(self):\n", - " return self.len\n" - ], - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "sxVlWjGhl02D" - }, - "source": [ - "num_samples = 10000" - ], - "execution_count": 4, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "V7ELesz1kVQo" - }, - "source": [ - "class BoringModel(LightningModule):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.layer = torch.nn.Linear(32, 2)\n", - "\n", - " def forward(self, x):\n", - " return self.layer(x)\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " loss = self(batch).sum()\n", - " self.log(\"train_loss\", loss)\n", - " return {\"loss\": loss}\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " loss = self(batch).sum()\n", - " self.log(\"valid_loss\", loss)\n", - "\n", - " def test_step(self, batch, batch_idx):\n", - " loss = self(batch).sum()\n", - " self.log(\"test_loss\", loss)\n", - "\n", - " def configure_optimizers(self):\n", - " return torch.optim.SGD(self.layer.parameters(), lr=0.1)" - ], - "execution_count": 5, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ubvW3LGSupmt" - }, - "source": [ - "---\n", - "## Define the test\n", - "NOTE: in Colab, set `progress_bar_refresh_rate` high or the screen will freeze because of the rapid tqdm update speed."
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4Dk6Ykv8lI7X" - }, - "source": [ - "def run():\n", - " train_data = DataLoader(RandomDataset(32, 64), batch_size=2)\n", - " val_data = DataLoader(RandomDataset(32, 64), batch_size=2)\n", - " test_data = DataLoader(RandomDataset(32, 64), batch_size=2)\n", - "\n", - " model = BoringModel()\n", - " trainer = Trainer(\n", - " default_root_dir=os.getcwd(),\n", - " limit_train_batches=1,\n", - " limit_val_batches=1,\n", - " num_sanity_val_steps=0,\n", - " max_epochs=1,\n", - " enable_model_summary=False,\n", - " )\n", - " trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)\n", - " trainer.test(model, dataloaders=test_data)" - ], - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4dPfTZVgmgxz" - }, - "source": [ - "---\n", - "## Run Test" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "AAtq1hwSmjKe", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 272, - "referenced_widgets": [ - "d79c1628eded487a974da18a2ea1f98b", - "02695b143b764932ba8d0c08a872987e", - "28eb6a3218f64f26abcdff756ffda3ad", - "02cfffd590014c3cbc44ab06c69f9181", - "0d7c50e36cb84f01a57a9d7d8b913393", - "6ba2782883ae424dbfc8868224d95da9", - "baa4aacd0da64cf291fb31c000724573", - "7dad3d2feced492a999fb6c91186be50", - "ea702a091eb642f7bdda81aa55db8c26", - "4802a47c6dfb439c83d8b860dce42006", - "68c87e6a7fcf4e4eab98a941c7c3e867", - "e6cbe583c2e14986b4faeb27e31f73e1", - "672dd78899f944cea7e57f388f3ecb31", - "cd61dda59d104e0a8a8aa9bfc1e55c24", - "1cd72d82332941a6929f88fad5173096", - "92a38638060c4ed5b6d44a2078667e53", - "bdc9b06391ee47478efd58cc91ca87ac", - "ee80657d62c6452d9e9ac199157cdf2a", - "eb16b87bcb8d4ca6a83e8b44ea2d1311", - "2a6327dd568241e3acbb6aec1926bd80", - "a45aba8517e14654850453159780b54a", - "7fb167222e7143b789b7f40af7cb39dd", - "abe1c0c4dac94e0e9b894bb69c3ec450", - "23763e19d40d4020b3342a47366e2e19", - "0b7b7da6a6134f0fb26a05adc062ee6f", - "9941635d9d694ba7bce0c7a14c500e5e", - "c7f1407ba92f4dc6ba34bd9cf73fea69", - "86f2e0a558cc419e84ed9192ccd3d1b6", - "141a9c35ade14d9e8645b2c108ab4d66", - "833bb79bb1214a3a88795f41b9375690", - "51b32955ad544803b1d78f07bc685569", - "a6b2764a5fa9444a9d77e8d74c67ef47", - "3f0c08c03e284ebb905dae8aca72fffc" - ] - }, - "outputId": "59e8bcf2-a944-46fc-a771-e7cbbbe4727d" - }, - "source": [ - "run()" - ], - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "GPU available: True, used: False\n", - "TPU available: False, using: 0 TPU cores\n", - "IPU available: False, using: 0 IPUs\n", - "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py:1567: UserWarning: GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.\n", - " \"GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.\"\n", - "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/data_loading.py:395: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). 
Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", - " f\"The number of training samples ({self.num_training_batches}) is smaller than the logging interval\"\n" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d79c1628eded487a974da18a2ea1f98b", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "Training: 0it [00:00, ?it/s]" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e6cbe583c2e14986b4faeb27e31f73e1", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "abe1c0c4dac94e0e9b894bb69c3ec450", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "Testing: 0it [00:00, ?it/s]" - ] - }, - "metadata": {} - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "--------------------------------------------------------------------------------\n", - "DATALOADER:0 TEST RESULTS\n", - "{'test_loss': -1.676544427871704}\n", - "--------------------------------------------------------------------------------\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Flyi--SpvsJN" - }, - "source": [ - "---\n", - "## Environment\n", - "Run this to get the environment details" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "0-yvGFRoaDSi" - }, - "source": [ - "%%capture\n", - "! wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/requirements/collect_env_details.py" - ], - "execution_count": 8, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "quj4LUDgmFvj", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "bb7a5f74-d52c-4927-b12a-49589aed7dcb" - }, - "source": [ - "! 
python collect_env_details.py" - ], - "execution_count": 9, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "* CUDA:\n", - "\t- GPU:\n", - "\t\t- Tesla K80\n", - "\t- available: True\n", - "\t- version: 11.1\n", - "* Packages:\n", - "\t- numpy: 1.19.5\n", - "\t- pyTorch_debug: False\n", - "\t- pyTorch_version: 1.10.0+cu111\n", - "\t- pytorch-lightning: 1.5.1\n", - "\t- tqdm: 4.62.3\n", - "* System:\n", - "\t- OS: Linux\n", - "\t- architecture:\n", - "\t\t- 64bit\n", - "\t\t- \n", - "\t- processor: x86_64\n", - "\t- python: 3.7.12\n", - "\t- version: #1 SMP Sat Jun 5 09:50:34 PDT 2021\n" - ] - } - ] - } - ] -} diff --git a/pl_examples/bug_report/bug_report_model.ipynb b/pl_examples/bug_report/bug_report_model.ipynb new file mode 100644 index 0000000000000..a6cb1933f113d --- /dev/null +++ b/pl_examples/bug_report/bug_report_model.ipynb @@ -0,0 +1,267 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "bug_report_model.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "rR4_BAUYs3Mb" + }, + "source": [ + "![image.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAQYAAABSCAYAAAC2XXppAAAaWElEQVR4Ae2dh3dU1RbG3x/i0mXvoighgKgI2BtW7KKiSAcFFVBQmgqKil2UJzYUQXl2ERSRpoiIIJhMOiEQkpCQRup+63eSHS73zkxmkmFC2Wetm5m55ZTv7P3tcs5M/iNWDAFDwBDwIfAf32f7aAgYAoaAGDGYEBgChkAAASOGACR2whAwBIwYTAYMAUMggIARQwASO2EIGAJGDCYDhoAhEEDAiCEAiZ0wBAwBIwaTAUPAEAggYMQQgMROGAKGgBGDyYAhYAgEEDBiCEBiJwwBQ8CIwWTAEDAEAggYMQQgsROGgCFgxGAyYAgYAgEEjBgCkNgJQ8AQMGIwGTAEDIEAAkYMAUjshCFgCBgxmAwYAoZAAAEjhgAkdsIQMASMGEwGDAFDIIBAhxJDXW2j7NldL6WFdVJb0xjonJ0wBAyBjkGgQ4ihoUGkvLhONi6vlLlPFMg7j++Qv5aVS0VZvdTXG0F0jChYq4bAPgSSSwyNIpV7GiTzrypZMKNIJvXLlhHdQzKye0ie6pcjC57bKRkbqqW6vF4ajR/2zZK9MwSSjEDSiGFvVaMUZNbI93NL5Lk7cmRUz5AMS0mX4Rxdm47RF2TIs3fmynfv7pYdmTWyt7ohyXBYc4aAIQACB5wYyCMUFdTJmsVlMntIvoy9OFOGdQ3JsNR0GdEttO9IDbUQxNjemfLqsHxZ9UWZ7C6slTrLP5i0GgJJReCAEQOhQNmuOvn7l0r574QCGX95lgyHELqmywg/KXgJovme4d1CMv6KbJk7YYdsWlHhkpQN9UnFxhozBI5YBBJPDJpH2FgtC2cVydM3ZMuIbhkyrEu6DI9GCF5y6NbkPUAio3pkyuQbc2TRrELJ3FQtVXuMHY5YabWBJw2BhBJDTVWjyw38MK9EZt6dJ6MvyGzyEuIghHDhBbmIhy/MlJn35MmS9y3/kDTpsIaOWAQSQgyN9SK7ttXK2q/L5PWR22Vsb2/Y4Mkj+LyC/Ugg2rXUphBkeGpIHuuT5dqgLfIP9bW2fHHESq8N/IAh0C5iII+wp7hetv5WKe9P2ikTroIQMppWG+L1ElD+1FZIpDn/MKpbSJ68Ols+eHqnbF5VKRWl9cLeCCuGgCGQGATaRQx7iurkqzdKZHr/HBnZo42EgKfQTApjLsqQUT0z9q1UhPMimlcvyD+M7pkh027JlR/nlUhZUV1iELFaDAFDoO3LlY0NIqEN1fJwrwjLj+GUOtw5FyaE5ImrsmXuuO0yplemEDK0GmZAECnpbpVjWv8c2fp7pU2nIWAIJAiBNnsMuO5b11bK0C4sP8agyFFIYcLVWfLd3BJZvqC0dULw15MSkonX5sjmXysSBIlVYwgYAu0jht8qY7PufmXmc2pIhqaky8TrcmTVwlIp3VUvC2YUykPn+DY+hXvWey61aTv15pWxEUNjY6Ps2bNH/v77b9mwYUPg4HxmZqaUlpbGLR01NTWSnZ0dqDNcO3ruzz//lFAoJBUVsfU/7k7F8ABj3bx5s6xfv162b98udXVtD8tyc3Pl999/d0dxcbGAd7RC27T722+/SX5+vjS0I1mUk5Pj6vnjjz86FE/w++uvv1xftm7dGm34B+21jiEGwocu6TKpX46sXFQqe6sbHTHMGrhNhpx3YImhvr5e1q1bJ7fddptce+217rj66qvlqquucsf1118vDz30kLz55pvyzz//SG1tbcyTh1LNmDFDrrzySqFOPWjnuuuuc8c111zTcp7rl19+uYwdO1Y2bdoUczuJvnHt2rUycOBAueKKK+T9999vEylqn+bMmeNwvOyyy2TFihWt4sdcgHmvXr1c21VVVVpV3K+vv/660C54o5CtkZK3AQhq586dcT3jfd77HsODfF188cXyyCOPeC8dMu+TTwyQQkq6TLo+R1YvLnOkQL4i79+
9Mu7SrLg2Qbk8RJweA2yOwJ5//vnSqVMnOeecc+Tcc8+V8847z71y7owzzpAuXbrI4MGDBaWJ1Yrl5eXJ448/LmeffbZ07ty55aANDs7z6r125plnyr333uusZkdJzc8//+wU6rTTTpNXXnlFsPRtLS+88IIb31lnnSVLliwRvKhoZdWqVXLBBRfIKaecIm+88YZUVrY9VzR9+nShXeYSzy/WecO6z549W1577bVW+xttLHqtrKxMLrnkEjn11FPd3Or5Q+k1ucTQHD48dX2OrFpcJjXVTW4m34X44/ty912JuPMVbSQGhBEBuvHGG2XmzJnueO655+SJJ56Q/v37O+GGMEaMGCHbtm2LaU53794tixcvFgT02WeflWeeecYdDzzwgHTr1s21d8cdd8iUKVNark2dOlU++OADwQXvqAIx4OWcfvrp7SaGWbNmtRDsjz/+2KqiEUY9//zz8tRTT8ny5ctbvT8aRmAOsaekpDgPLFZieOyxx5wxGDRokLTHY9G+VVdXy0svveTGNG/ePD19SL0mlRj4evXUm3Nk9f/2kQJoVZc3yOezimTQOWnxJx/bQAy//vqrs1IIEESA61deXu4OrCUexX333eesT58+feSzzz6LycXEdUUoqM97fPzxx9K3b1/p2rWrU7yioqL9riOMhDhaeI9nE6tg63PRXqmPI1yBGAilIAasZklJiRsv98fjjlM3xADh4h3FQgy0AVZY2WjeBdf27t0brvst5yAG2gXnLVu2uL7H8hyeId4Sr4kgBjqEPDGm9nhALQPzvEE2wCHeefFUEdPbpBLDY30yZfmnpS2egvawuqJBFs8ulqFdQjKUL1l5k4utvW8HMSBAWCqvUtInBHXhwoXO8mDpp02b5rwGLBqkkZ6erl1veWWy0tLSBCUj+VZYWNhy7fPPP3euJUREHBxO+OgDzxBzf/PNN7Jo0SL57rvvXCITRfUWhG3jxo3y/fffu7YyMjJk9erV8sMPP7i8COREQSlIpNIn+oA3Q/9JkHpzJ0oMhFAvv/yya3PZsmWuD/QFtzxWAY+XGCBJcKXv9NU/F+RtCDfAY/78+Y5sGC8YQzwrV65sGYsSQ2pqqrsOlswjzxHW8JwqFN7dTz/9JEuXLnX5AEK6W265Rb799lt3LwlRf19Q9F9++cU9w3NgDpb+ArbgzJhIhIYryBDXqYd+EM6ARaSCDDCeL774Qj788EP5+uuvnRySNKYeMKR/iSpJIwb2JjzeN1N++rg08EUovjW5La1G5jxa0LRaEQ85tIMYUNRwxIBAIHA9e/Z01oe8AZ9vvvlmF4vzDOThLZDChAkTXMJpzJgxLsuv11FKPAbawyL7lQxSYXWCxOUNN9wg3bt3dzkO2r/99tudl0EyTQWV0IYQ5MILL5Q777zThTuEAuRFCGNIotE/lHrUqFFy6aWXurZpn/f0DzLQfigxEJ+TeCX0IXHG/YRcd911l3z66aeCMrVW4iUGViNIzDLW9957r6VPtAP5TZ48WUjYMjbyM/Rn5MiRDhf6yPi1X15iwBPs16+fC2sICbl32LBhTrnwxFBEzukck/dhvOSeIBZCSCVYHXNWVpYLuS666CLhIHkMvii21xsDe2SFurkerpBPoR3q6d27tyMl5h9y9BdWW8jdgBN9BAeeffDBB4XwR/uCDCaqJI0Y1AsgwfjJs4Wya9v+SSnIobigVr55u9h912JISoyeQzuIQT0GP5iEEySjmAAEhfckFp988knnbpO9x1prQbHJEzBBPXr0cCsauJJaohEDVh3PAAFHgJl4svQDBgxwgodbzjkEAHJCqBEUyIp4GoHGq8Ha3X333U65yFeQRIQEcK3J+KPctIGwQgAQ1SeffOIIRImB8dIe93MvpMR4qIOxQQ7ecen4vK/xEoMSsD/5SJ9QLjCBFMjygwHEyTnGzfhRTi8xMAa9ritM1MN5xsHzeA4cKP/999/viJhrEAVtQIyMg7nBw0AeduzY4aw/JEV4CSkTeik2eCVasNxgT/KRkFQLcgJpUxc5CGRLQ0zqot+Mk7yLFpSdvBTXmDf6P2TIEDc3zBXj4mCeCJ8SVZJLDGyESg253Y0fTimUwtwa2W+Vm99wKGmQpfNKZGwfdkDGQA7tIAYUbvjw4W69GdcU6wX7M2lYYCYCq8PKBJOKO4rFglAmTZrUEhIwkSxLcT/Kyf3qsjJR0YgBT2D06NFOyBFMEpbkQLBouPN4IUw6gvH000+7kAbF19UPlMb7DHsB6CdLdrjHCNJHH33kVj1wa7HKjIkxEFOjILih5BiUMFhy5F4OFASLRgz+8MMPt5okjZcYCBNQMurHirKfA8XCuiPw9JM616xZ48Ik7h8/frw7r8SgoRYeA89wnpUeyIUlZ+aDsUImKCAEh1X/999/XciBEoMVc8f+EtxzyJc5JCxAHrD84I9BIIT56quvZNy4ca5OcIOEdDVHiYG2IB4t1Es9ENI777zj6iIMIDQgCU49hHO0R6F9wljOMSa8IMIXCICwAg9DPQgI/9AlBvIFXUNuU9TYXpmy6MWiwLcjWaHYsKxCxvaNfWs0vxcZ6wYnXD5NPiIoKOOtt97qmJpXlAaQmVSsJEqim49w8xBYJgoLjavL5H355ZdOsWBwVjhUQFQgIhED1p+4Hw+EvuCRoKgII2EDFguCQJAQmptuuskpsRID51Bo7uEZxsZzKAheBEqFJ0PsqtdYr2f8tIuQ01f1GBgzeyoYp95Pf/BeEEysWWvuanuJgfAGoYeMGB/EzXgVE17pH33BWvs9Bs7RVzDXpC7P4JXhfmPFmSOwBSu8DTBkTrHEeESc18QvBoG2uM5qFffzLOcJLSATSAXiRh4okYgBo8OY6AOrFdTDQT8hCPpHXXh3FOaGfRnMC3IJWSkOzA8hJaEfdR76xAA5NHsO744rEPYwaKmvFflnTaW8MDBPRup9Byj5iJXCujARJ598sjtwZ5k0FApXGqXC7dMCQeB+IwQ8/9Zbb7mYmLgexSYGRNlUqPS5SMRADEsdCATuKZYMofQWhIZ7UHSUhWQaioK1QlgRTLwELQg2ysTYcHuVvPQ6RKaCjXDz2UsMJB+9xEb7WDnGh1VrbSdfIogBvBgv48PDoQ/eAkbkeSIRA1aUcXvnAWuN646SYYUZOwXSxDtUYvC35SUGjIE3lKJ+rDveDsqJN0OJRgzImxKDu7n5D+0yX/QPnCmQMB4h55hv73i4zjzieXD9sCGGUT1C8u2cfRtpIIUtayrl+QF5TV+/jvX7F+0IJQCd+BNhfvHFF91BcpDlSbLSWFfvZKBE7FAk9kYoiUe5j89YKSYPVveXSMSAMNAegow7S8bZX7AMKAfhBB4MxOQlBnYsFhQUtDxGn3GbIQbCiFjcSy8xvPrqq265UivEghPqJJsYIEsEHsvqV1b6RijHGMN5DOAJEXjnjs/qCULkSgy7du1qlRiYXxQaj8FLDPTj3XffbSEG8iWUSMQAzhAIROLf30BfIR6uQQzIGiSMEVJicJV7/uA90P5hRQyP9s2UTSuadrk5UlhbKc/fm+d+KD
amb1eqJ9EOYkCAiFexGnqgWFhyFJLJ8RfiWZQZYiBxxOYYBA5rzn4FWNxfIhEDbRBf4n2gCFhr2vcWvAGsI0qAZWOlQUMJrBwxrJcYEHjyD5Ae+RDyFP4+kUhFUSA5vCDNMSBgkYgBAY3XY6BehDda8ecYICLyPZAgSgTJeb026mKJECWNRAx4fIwtFmJgiRhcUXxCCT8JgWc0Ypg7d25CiAEviDEpMfAZWcOTpG8QIPPmLXh2zP/hE0p0C8nkG3OlKJ9fYBLZAinc1wZSaA432ppjgBhQunAE4J0A73smjCQlwoQVhRQgCWI91vvDlUjEwL0kk1TwsPCQC8IKaSAIfGeDHATCMXToUOcBRCMG6iRDzsoCAsMzKBpkB0EQv0McxM0ks/Bw/B6DN5RQjyFeYkBp8W5YrycW9x7kCGgDxQ1HDFhcPCG8MJZuSa6SCIUQwIv+E2pECiXiIQaIGKWjLhK2uO+QmZJKa8SQKI8BufJ6DNoHPCNCXWSNZWbGz/yTJCUkwqDQ90M/lEht+gczr43Ml9rqxqbwoa2k0A5iUCsNMTAp8RQ23ZChRllRANx8Mup+N1PrVGJgEv37GEhm8Sx1oAgIKQLA+j1JR5JPeAaEGuQgsGgaStC+32OgTZJSrPXTN8jvnnvucQTIWEls4U0Q5xIf4ykpMVAfHoOfGFiNgBhIfsaSY6B+BJn7Wfpjvd170B/Ii7FADHgHeCvgABFB1Hg6+n0DyAHypC6Uh7rBSvHyL1eCcziPAQ8KLL2hBAQMVlhq+g2RkgSGTJELiIHlQkg2UihB3+mLP5RQj07lAJy5j/v9oYQSA9fwzNTTAm8I68QTT3RLyYyfEJY+MbfMmRqoWMJG7Utrr8ldrmxW5FHnZ8iXrxVL2rpKmXFvbvzhg4YRbSQGJggwmQRCgXiJAcuL4BJKIDCsZrDTLZLngUJj4WiPfIaucjA5PIPVJsdBOEICFCGlXl5RYLLmuKysgVOwnCQEjznmGGf5ISpvYTzkPkhCMk4sDgKEoCJgWBiUgXCCe9l9h4dBfXxvAbLQQl9x50866STnubD8F61g3ekz99NWuOOoo45yxMeSIeEGpHP00Ue7UEqxgTRY7UEBUCatj7FAlqwm+UMJyPr44493OLNSo1af/pKMBAvamThxYkuOgWtYYbw1sKf+4447zi3xopwQA0uRxx57rCNtP/lDZtTJdcZCwQuB6DmPZ6YFnLmP8yQNvYV5wChwDQ9RiYHzEI56UOCpWECoyMZB5zFsWV0pAzulybCUkPtXc7qJKeor+xguzpT5U3fKzAF5MjSlafky6jNeIvC+d1/fDsm4K7Jk44p9G4q8gPvfAzTMinCQLFywYMF+AuS/P9JnFASBQThZHvQqk/8ZQg88APIZrFogbN4COZAEQ3BQLCwjqyIoJISBVVWryHNYdPpN/xEwf16CexgnBEJowhIkqxdYap6BqHDnVfgYC+3iGbCs5xV++op1e/TRR93uO+8KiHcM+p7x4f6CL5Y33EEoQK6E8IZQA5eYcyznQbpKsFhzQhCSsuz8o4+47pAeG4dQCJRJsWFvAeERnhFek9ZD3+i3tsN91K0FrAgD3377bTdH4MC+Cc6DEfsFmAu+ZEefvYVlRiw4+QklTe7hGZ0fvR+y4j7uZ669BRJjfFzDo6FtCmPgGuQPtsgD44CQIAzagPAJJVrz5rzttfa+zR4DO5N2ba9r2cb84NlpMf9oCx7DuMuyZERqRszPBIgjNSSDzm760tX8aeykjJ7k8gLBxAE0wqJC5b3e2nueR7jwAlgCY9+61zr5n8c91vawkl6B1Xs5hxBCECgtm294JQHlFWLu5zNkQFJOY3Wtx/tKn2ibXAX1ET8zZr9wY52pi/uI771joV+0gaJxj5KJtx3ve57HsyEhyv2RDu6jbohHdwNCSHqO3AoKjgcDdlxjzNxPP3G38YIIMdTLoE5tz99PPmu/whEpY+Z5+s1YwY1Cf2gfggI7PnsLdTFPkDA46jPkiajLG5Zxnfu4P1wfaIdrtKOFkItNbhgWFJ86kFnmkHETGuJdkqTk+USVthODiLCNuTCvTlYuLpMZd+c5RR0U4y8wuZWHWJckvV5Ct5A81DldBndOk5kDcmXlojIpzmdjT3yQMMF6xPdkUwyPBYWpEVCEprWibfkFy/+c975o9+p9/ufDfdZ79TXaPfFei3a/thfu1fucXtdzKCnWE9efRCJYYy2xuIRw7DIlj4DbT2jmXXXx16V18hrtmt4X7h49x2u4otf918LdH+lentVr3uf48hghJZvnCCcwSIRFhD+ETuQZkEM8Gj8Z+vsTz+d2EYNrqFGkrlokd3O1fPlmsfsJ+QfOTJMh/BZk9zb+FqSPCNRb4NuXD3ZKk4nXZss3b5VIfjr/+BYFj2fI7bsX8PmGHclLYmNc5UROSPt6d/g8TdadDT+aEyGfw+4/fvEKwjjhhBPcZ7ajexXp8EGgaSR4HnhF5CbIUeEZgAO5BeQPfMi3QBSJLO0nhube8FN9lbsbJW1dlXwweaeM7pkpEITzDNpJEOQwCFXGXJQpH0/dKaE/qtz/kvDumkwkKNHqwtVnSy0JMSaH7yVYSTwChAysDBBDk5RFKUi64SUQwrFio3mAxLd+8NQI6RFCkHdglUYTyOCAt4AXQb4CvBJZEkYM2qmGOpGSgjr5c+keeXXYdhlybpNSx/3LTHgNzXkEfjT29VH58ufSctldyA+YaGvJfyW2YymMRBmCS1xq5cAgQAKOOJ1kI8lKkqZ8x0PX8o8UTw1ywHNgFYnkMDgQYvCdF3IWBwKHhBODikhdjUhBVq38/MlumdY/VwZ1SnO5gVjDi8HnprtfjJ5+e46sWFgqu/JqpY78YhLDBh2L/xWB1cN/zT4nFgGUgpwDFhFS5uDz4Rw+REKQcZN4VhyQwQOFwwEjBje4BpGqikbJ3bJX/vdKkVuJILzAA4hEEC5s6JQm4y/Pcv/lyuURKpKbR4g0MXbeEDhSEDiwxNCMIq5/eUm9/LuuSuY9uUNG9AjJA532X94cTh6hU5qMPj/D/U/K0PrmPMJB4CEcKcJg4zQEFIGkEIM2Rv6heHudrF+yR14elC+DO6e7JU7+yQz/T2L24G2yflm5lHZwHkH7a6+GwJGKQFKJQUGu3SuyI6dWfpq/W6bcnC1Tb8lxPxK7a3ut8EMtB0MeQftqr4bAkYhAhxCDA7pBpLqyUXZk18r2jBrZW2GEcCQKoI354ESg44ihGQ82JyVzg9LBOQ3WK0Pg4EKgw4nh4ILDemMIGAIgYMRgcmAIGAIBBIwYApDYCUPAEDBiMBkwBAyBAAJGDAFI7IQhYAgYMZgMGAKGQAABI4YAJHbCEDAEjBhMBgwBQyCAgBFDABI7YQgYAkYMJgOGgCEQQMCIIQCJnTAEDAEjBpMBQ8AQCCBgxBCAxE4YAoaAEYPJgCFgCAQQMGIIQGInDAFDwIjBZMAQMAQCCBgxBCCxE4aAIWDEYDJgCBgCAQSMGAKQ2AlDwBD4P9CuROTFaWXrAAAAAElFTkSuQmCC)" + ] + 
}, + { + "cell_type": "markdown", + "metadata": { + "id": "i7XbLCXGkll9" + }, + "source": [ + "# The Boring Model\n", + "Replicate a bug you experience, using this model.\n", + "\n", + "[Remember! we're always available for support on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2LODD6w9ixlT" + }, + "source": [ + "---\n", + "## Setup env" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zK7-Gg69kMnG" + }, + "source": [ + "%%capture\n", + "! pip install -qU pytorch-lightning" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WvuSN5jEbY8P" + }, + "source": [ + "---\n", + "## Deps" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "w4_TYnt_keJi" + }, + "source": [ + "import os\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "from pytorch_lightning import LightningModule, Trainer" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XrJDukwPtUnS" + }, + "source": [ + "---\n", + "## Data\n", + "Random data is best for debugging. If you needs special tensor shapes or batch compositions or dataloaders, modify as needed" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "hvgTiaZpkvwS" + }, + "source": [ + "class RandomDataset(Dataset):\n", + " def __init__(self, size, num_samples):\n", + " self.len = num_samples\n", + " self.data = torch.randn(num_samples, size)\n", + "\n", + " def __getitem__(self, index):\n", + " return self.data[index]\n", + "\n", + " def __len__(self):\n", + " return self.len" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "sxVlWjGhl02D" + }, + "source": [ + "num_samples = 10000" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "V7ELesz1kVQo" + }, + "source": [ + "class BoringModel(LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.layer = torch.nn.Linear(32, 2)\n", + "\n", + " def forward(self, x):\n", + " return self.layer(x)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " loss = self(batch).sum()\n", + " self.log(\"train_loss\", loss)\n", + " return {\"loss\": loss}\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " loss = self(batch).sum()\n", + " self.log(\"valid_loss\", loss)\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " loss = self(batch).sum()\n", + " self.log(\"test_loss\", loss)\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.SGD(self.layer.parameters(), lr=0.1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ubvW3LGSupmt" + }, + "source": [ + "---\n", + "## Define the test" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4Dk6Ykv8lI7X" + }, + "source": [ + "def run():\n", + " train_data = DataLoader(RandomDataset(32, 64), batch_size=2)\n", + " val_data = DataLoader(RandomDataset(32, 64), batch_size=2)\n", + " test_data = DataLoader(RandomDataset(32, 64), batch_size=2)\n", + "\n", + " model = BoringModel()\n", + " trainer = Trainer(\n", + " default_root_dir=os.getcwd(),\n", + " limit_train_batches=1,\n", + " limit_val_batches=1,\n", + " limit_test_batches=1,\n", + " num_sanity_val_steps=0,\n", + " max_epochs=1,\n", + " enable_model_summary=False,\n", + " 
)\n", + " trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)\n", + " trainer.test(model, dataloaders=test_data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4dPfTZVgmgxz" + }, + "source": [ + "---\n", + "## Run Test" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AAtq1hwSmjKe" + }, + "source": [ + "run()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Flyi--SpvsJN" + }, + "source": [ + "---\n", + "## Environment\n", + "Run this to get the environment details" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0-yvGFRoaDSi" + }, + "source": [ + "%%capture\n", + "! wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/requirements/collect_env_details.py" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "quj4LUDgmFvj" + }, + "source": [ + "! python collect_env_details.py" + ], + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/pl_examples/bug_report/bug_report_model.py b/pl_examples/bug_report/bug_report_model.py index 270b0cd2abe8d..7739630237d32 100644 --- a/pl_examples/bug_report/bug_report_model.py +++ b/pl_examples/bug_report/bug_report_model.py @@ -53,6 +53,7 @@ def run(): default_root_dir=os.getcwd(), limit_train_batches=1, limit_val_batches=1, + limit_test_batches=1, num_sanity_val_steps=0, max_epochs=1, enable_model_summary=False, From 1284ead317d25b4d321c72d4f79e913e157022d3 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 22 Nov 2021 19:59:06 +0530 Subject: [PATCH 12/59] Remove metrics references from docs (#10567) --- docs/source/advanced/multi_gpu.rst | 2 +- docs/source/extensions/logging.rst | 2 +- docs/source/extensions/metrics.rst | 9 --------- docs/source/index.rst | 1 - pyproject.toml | 1 - 5 files changed, 2 insertions(+), 13 deletions(-) delete mode 100644 docs/source/extensions/metrics.rst diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 3e4a7b2335b62..8384209aff0e8 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -90,7 +90,7 @@ This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the valid This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers. The ``sync_dist`` option can also be used in logging calls during the step methods, but be aware that this can lead to significant communication overhead and slow down your training. -Note if you use any built in metrics or custom metrics that use the :doc:`Metrics API <../extensions/metrics>`, these do not need to be updated and are automatically handled for you. +Note if you use any built in metrics or custom metrics that use `TorchMetrics `_, these do not need to be updated and are automatically handled for you. .. testcode:: diff --git a/docs/source/extensions/logging.rst b/docs/source/extensions/logging.rst index 1facdb93373eb..e652adbecc419 100644 --- a/docs/source/extensions/logging.rst +++ b/docs/source/extensions/logging.rst @@ -111,7 +111,7 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a .. 
note:: - Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a - reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction. + reduction in ``on_train_epoch_end``. We recommend using `TorchMetrics `_, when working with custom reduction. - Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor` diff --git a/docs/source/extensions/metrics.rst b/docs/source/extensions/metrics.rst deleted file mode 100644 index 74a4a15deb2be..0000000000000 --- a/docs/source/extensions/metrics.rst +++ /dev/null @@ -1,9 +0,0 @@ -####### -Metrics -####### - -``pytorch_lightning.metrics`` has been moved to a separate package `TorchMetrics `_. -We will preserve compatibility for the next few releases, nevertheless, we encourage users to update to use this stand-alone package. - -.. warning:: - ``pytorch_lightning.metrics`` is deprecated from v1.3 and will be removed in v1.5. diff --git a/docs/source/index.rst b/docs/source/index.rst index 72da9c3e354c4..c1b20b958591b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -84,7 +84,6 @@ PyTorch Lightning extensions/callbacks extensions/datamodules extensions/logging - extensions/metrics extensions/plugins extensions/loops diff --git a/pyproject.toml b/pyproject.toml index 08b7b50eee770..c527ffaa856cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,6 @@ module = [ "pytorch_lightning.core.*", "pytorch_lightning.loggers.*", "pytorch_lightning.loops.*", - "pytorch_lightning.metrics.*", "pytorch_lightning.overrides.*", "pytorch_lightning.plugins.environments.*", "pytorch_lightning.plugins.training_type.*", From 48cb38ac5dd0159c8f7c5189c888dfd04a2ed34b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 22 Nov 2021 15:52:21 +0100 Subject: [PATCH 13/59] Fix docs filterwarnings snippet (#10671) --- docs/source/guides/speed.rst | 4 +--- pytorch_lightning/trainer/data_loading.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/guides/speed.rst b/docs/source/guides/speed.rst index 576589c670635..04613a89bb35a 100644 --- a/docs/source/guides/speed.rst +++ b/docs/source/guides/speed.rst @@ -151,9 +151,7 @@ For debugging purposes or for dataloaders that load very small datasets, it is d import warnings - warnings.filterwarnings( - "ignore", ".*does not have many workers. Consider increasing the value of the `num_workers` argument*" - ) + warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") Spawn """"" diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 85a1df650d179..d7354a8294b37 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -104,6 +104,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: ) elif dataloader.num_workers <= 2 < num_cpus and not using_spawn: + # if changed, update the `filterwarnings` snippet in 'speed.html#num-workers' rank_zero_warn( f"The dataloader, {name}, does not have many workers which may be a bottleneck." 
" Consider increasing the value of the `num_workers` argument`" From 991cd895c68140f87518b640eff3b3b177eebddb Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 22 Nov 2021 14:58:23 +0000 Subject: [PATCH 14/59] 1/n Add `FaultTolerantMode` (#10645) --- CHANGELOG.md | 3 +- pytorch_lightning/trainer/states.py | 6 +++- pytorch_lightning/utilities/auto_restart.py | 1 - pytorch_lightning/utilities/enums.py | 36 +++++++++++++++++++++ tests/utilities/test_auto_restart.py | 23 ++++++++++++- 5 files changed, 65 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bba1cd319e706..e678be0e965fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601)) -- +- Fault Tolerant Manual + * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) - diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 7f83dd76156ab..a81073cccc1c0 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional from pytorch_lightning.utilities import LightningEnum +from pytorch_lightning.utilities.enums import _FaultTolerantMode class TrainerStatus(LightningEnum): @@ -93,6 +94,9 @@ class TrainerState: fn: Optional[TrainerFn] = None stage: Optional[RunningStage] = None + # detect the fault tolerant flag + _fault_tolerant_mode: _FaultTolerantMode = field(default_factory=_FaultTolerantMode.detect_current_mode) + @property def finished(self) -> bool: return self.status == TrainerStatus.FINISHED diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index ef52717636d90..228e16e4e9c8c 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index cbb4f68bedfac..1d7a6e3fa5452 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Enumerated utilities.""" +import os from enum import Enum, EnumMeta from typing import Any, List, Optional, Union +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import rank_zero_deprecation @@ -256,3 +258,37 @@ def interactive_compatible_types() -> List["_StrategyType"]: def is_interactive_compatible(self) -> bool: """Returns whether self is interactive compatible.""" return self in _StrategyType.interactive_compatible_types() + + +class _FaultTolerantMode(LightningEnum): + + DISABLED = "disabled" + AUTOMATIC = "automatic" + MANUAL = "manual" + + @property + def is_enabled(self) -> bool: + return self is not _FaultTolerantMode.DISABLED + + @property + def is_automatic(self) -> bool: + return self is _FaultTolerantMode.AUTOMATIC + + @property + def is_manual(self) -> bool: + return self is _FaultTolerantMode.MANUAL + + @classmethod + def detect_current_mode(cls) -> "_FaultTolerantMode": + """This classmethod detects if `Fault Tolerant` is activated and maps its value to `_FaultTolerantMode`.""" + env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0").lower() + # the int values are kept for backwards compatibility, but long-term we want to keep only the strings + if env_value in ("0", "disabled"): + return _FaultTolerantMode.DISABLED + elif env_value in ("1", "automatic"): + return _FaultTolerantMode.AUTOMATIC + elif env_value in ("2", "manual"): + return _FaultTolerantMode.MANUAL + raise MisconfigurationException( + "The environment flag `PL_FAULT_TOLERANT_TRAINING` should be either 'disabled', 'automatic', or 'manual'." + ) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b36a9d1d76941..b9eb97cb42ae8 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -35,6 +35,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, _dataloader_load_state_dict, @@ -44,7 +45,7 @@ FastForwardSampler, MergedIteratorState, ) -from pytorch_lightning.utilities.enums import AutoRestartBatchKeys +from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -1192,3 +1193,23 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" not in state_dict else: assert "dataloader_state_dict" in state_dict + + +def test_fault_tolerant_mode_enum(): + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): + assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode() + assert not TrainerState()._fault_tolerant_mode.is_enabled + + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): + assert _FaultTolerantMode.AUTOMATIC == _FaultTolerantMode.detect_current_mode() + assert TrainerState()._fault_tolerant_mode.is_automatic + + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "MANUAL"}): + assert _FaultTolerantMode.MANUAL == _FaultTolerantMode.detect_current_mode() + assert TrainerState()._fault_tolerant_mode.is_manual + + with pytest.raises( + MisconfigurationException, 
match="The environment flag `PL_FAULT_TOLERANT_TRAINING` should be either" + ): + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): + _FaultTolerantMode.detect_current_mode() From a6dedcf492456c1ef4e29b084b2b970f79667f17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 22 Nov 2021 16:58:21 +0100 Subject: [PATCH 15/59] Fix `move_metrics_to_cpu` with evaluation (#10631) --- CHANGELOG.md | 4 ++-- .../loops/epoch/evaluation_epoch_loop.py | 12 ++++++---- .../logging_/test_eval_loop_logging.py | 24 +++++++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e678be0e965fe..6942ffc542143 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -171,10 +171,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) -- +- Fixed `Trainer(move_metrics_to_cpu=True)` not moving the evaluation logged results to CPU ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631)) -- +- Fixed the `{validation,test}_step` outputs getting moved to CPU with `Trainer(move_metrics_to_cpu=True)` ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631)) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index b4660c96a0989..102603f20302b 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -24,7 +24,6 @@ from pytorch_lightning.trainer.progress import BatchProgress from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher -from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT @@ -134,10 +133,13 @@ def advance( self.trainer.logger_connector.update_eval_step_metrics() # track epoch level outputs - if self._should_track_batch_outputs_for_epoch_end(): - output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu) - if output is not None: - self.outputs.append(output) + if self._should_track_batch_outputs_for_epoch_end() and output is not None: + self.outputs.append(output) + + if self.trainer.move_metrics_to_cpu: + # the evaluation step output is not moved as they are not considered "metrics" + assert self.trainer._results is not None + self.trainer._results.cpu() if not self.batch_progress.is_last_batch: # if fault tolerant is enabled and process has been notified, exit. 
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 88229effbc8c9..66c91eaf15f1b 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -26,6 +26,7 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf def test__validation_step__log(tmpdir): @@ -699,3 +700,26 @@ def test_filter_metrics_for_dataloader(kwargs, expected): """Logged metrics should only include metrics from the concerned dataloader.""" actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs) assert actual == expected + + +@RunIf(min_gpus=1) +def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir): + class TestModel(BoringModel): + def validation_step(self, *args): + x = torch.tensor(2.0, requires_grad=True, device=self.device) + y = x * 2 + assert x.requires_grad is True + assert y.grad_fn is None # disabled by validation + + self.log("foo", y) + return y + + def validation_epoch_end(self, outputs): + # the step outputs were not moved + assert all(o.device == self.device for o in outputs), outputs + # but the logging results were + assert self.trainer.callback_metrics["foo"].device.type == "cpu" + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, gpus=1) + trainer.validate(model, verbose=False) From 6810c40fc9188ac76e4daccdf42503ab0f2b2c5e Mon Sep 17 00:00:00 2001 From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Date: Mon, 22 Nov 2021 11:38:09 -0500 Subject: [PATCH 16/59] Small improvements to `_init_debugging_flags` (#10620) --- .../trainer/connectors/callback_connector.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 23 ++++++++----------- tests/trainer/test_dataloaders.py | 7 ++---- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 03926aeb9bc68..55730c0d79c72 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -94,7 +94,7 @@ def on_trainer_init( " bar pass `enable_progress_bar = False` to the Trainer." 
) - self.configure_progress_bar(progress_bar_refresh_rate, process_position, enable_progress_bar) + self._configure_progress_bar(progress_bar_refresh_rate, process_position, enable_progress_bar) # configure the ModelSummary callback self._configure_model_summary_callback(enable_model_summary, weights_summary) @@ -211,7 +211,7 @@ def _configure_swa_callbacks(self): if not existing_swa: self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks - def configure_progress_bar( + def _configure_progress_bar( self, refresh_rate: Optional[int] = None, process_position: int = 0, enable_progress_bar: bool = True ) -> None: progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 38eb44bced223..cb446632cc9e3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -584,29 +584,24 @@ def _init_debugging_flags( overfit_batches, fast_dev_run, ): - if not isinstance(fast_dev_run, (bool, int)): - raise MisconfigurationException( - f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be either a bool or an int >= 0" - ) - if isinstance(fast_dev_run, int) and (fast_dev_run < 0): raise MisconfigurationException( f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0." ) self.fast_dev_run = fast_dev_run - fast_dev_run = int(fast_dev_run) # set fast_dev_run=True when it is 1, used while logging if fast_dev_run == 1: self.fast_dev_run = True if fast_dev_run: - limit_train_batches = fast_dev_run - limit_val_batches = fast_dev_run - limit_test_batches = fast_dev_run - limit_predict_batches = fast_dev_run - self.fit_loop.max_steps = fast_dev_run + num_batches = int(fast_dev_run) + limit_train_batches = num_batches + limit_val_batches = num_batches + limit_test_batches = num_batches + limit_predict_batches = num_batches + self.fit_loop.max_steps = num_batches self.num_sanity_val_steps = 0 self.fit_loop.max_epochs = 1 val_check_interval = 1.0 @@ -615,7 +610,7 @@ def _init_debugging_flags( rank_zero_info( "Running in fast_dev_run mode: will run a full train," - f" val, test and prediction loop using {fast_dev_run} batch(es)." + f" val, test and prediction loop using {num_batches} batch(es)." 
) self.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches") @@ -624,9 +619,9 @@ def _init_debugging_flags( self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches") self.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval") self.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches") - self.determine_data_use_amount(self.overfit_batches) + self._determine_data_use_amount(self.overfit_batches) - def determine_data_use_amount(self, overfit_batches: float) -> None: + def _determine_data_use_amount(self, overfit_batches: float) -> None: """Use less data for debugging purposes.""" if overfit_batches > 0: self.limit_train_batches = overfit_batches diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 93a98bbff4d34..59c88e972d6a6 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -505,7 +505,7 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v assert mocked.call_count == limit_test_batches * len(trainer.test_dataloaders) -@pytest.mark.parametrize("fast_dev_run", [True, 1, 3, -1, "temp"]) +@pytest.mark.parametrize("fast_dev_run", [True, 1, 3, -1]) def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): """Verify num_batches for train, val & test dataloaders passed with fast_dev_run.""" model = EvalModelTemplate() @@ -518,10 +518,7 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): trainer_options = dict(default_root_dir=tmpdir, max_epochs=2, fast_dev_run=fast_dev_run) - if fast_dev_run == "temp": - with pytest.raises(MisconfigurationException, match="either a bool or an int"): - Trainer(**trainer_options) - elif fast_dev_run == -1: + if fast_dev_run == -1: with pytest.raises(MisconfigurationException, match="should be >= 0"): Trainer(**trainer_options) else: From d431ce14a110a0e42e03ba7132faec437d906c7f Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 22 Nov 2021 22:25:19 +0530 Subject: [PATCH 17/59] Raise an error if batch_size cannot be inferred from current batch (#10541) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- CHANGELOG.md | 10 ++++-- .../connectors/logger_connector/result.py | 7 ++-- pytorch_lightning/utilities/data.py | 34 ++++++++++++------- tests/utilities/test_data.py | 21 ++++++++++-- 4 files changed, 50 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6942ffc542143..7fc3178b4426e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- Raise exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418)) +- Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418)) - The `monitor` argument in the `EarlyStopping` callback is no longer optional ([#10328](https://github.com/PyTorchLightning/pytorch-lightning/pull/10328)) @@ -35,7 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Do not fail if batch size could not be inferred for logging when using DeepSpeed ([#10438](https://github.com/PyTorchLightning/pytorch-lightning/issues/10438)) -- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Raised `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) - Moved `trainer.connectors.env_vars_connector._defaults_from_env_vars` to `utilities.argsparse._defaults_from_env_vars` ([#10501](https://github.com/PyTorchLightning/pytorch-lightning/pull/10501)) @@ -50,6 +50,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) +- Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541)) + + +- + + - diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 1b82baf0440c9..ab3c0f1804c2a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -410,11 +410,8 @@ def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int batch_size = 1 if self.batch is not None and meta.on_epoch and meta.is_mean_reduction: - try: - batch_size = extract_batch_size(self.batch) - self.batch_size = batch_size - except RecursionError: - pass + batch_size = extract_batch_size(self.batch) + self.batch_size = batch_size return batch_size diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index bbe41217f1346..5b56940460ca4 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -40,16 +40,14 @@ def _extract_batch_size(batch: BType) -> Generator[int, None, None]: yield 1 else: yield batch.size(0) - elif isinstance(batch, str): - yield len(batch) - elif isinstance(batch, (Iterable, Mapping)): + elif isinstance(batch, (Iterable, Mapping)) and not isinstance(batch, str): if isinstance(batch, Mapping): batch = batch.values() for sample in batch: yield from _extract_batch_size(sample) else: - yield 1 + yield None def extract_batch_size(batch: BType) -> int: @@ -58,16 +56,26 @@ def extract_batch_size(batch: BType) -> int: Returns: ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable. """ + error_msg = ( + "We could not infer the batch_size from the batch. Either simplify its structure" + " or provide the batch_size as `self.log(..., batch_size=batch_size)`." + ) batch_size = None - for bs in _extract_batch_size(batch): - if batch_size is None: - batch_size = bs - elif batch_size != bs: - warning_cache.warn( - "Trying to infer the `batch_size` from an ambiguous collection. The batch size we" - f" found is {batch_size}. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`." 
- ) - break + try: + for bs in _extract_batch_size(batch): + if batch_size is None: + batch_size = bs + elif batch_size != bs: + warning_cache.warn( + "Trying to infer the `batch_size` from an ambiguous collection. The batch size we" + f" found is {batch_size}. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`." + ) + break + except RecursionError: + raise RecursionError(error_msg) + + if batch_size is None: + raise MisconfigurationException(error_msg) return batch_size diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index f4c61cda64f5d..ae1f8c6505efc 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -28,9 +28,11 @@ def _check_warning_raised(data, expected): assert extract_batch_size(batch) == expected warning_cache.clear() - batch = "test string" - _check_warning_not_raised(batch, 11) + def _check_error_raised(data): + with pytest.raises(MisconfigurationException, match="We could not infer the batch_size"): + extract_batch_size(batch) + # Warning not raised batch = torch.zeros(11, 10, 9, 8) _check_warning_not_raised(batch, 11) @@ -43,6 +45,7 @@ def _check_warning_raised(data, expected): batch = {"test": [{"test": [torch.zeros(11, 10)]}]} _check_warning_not_raised(batch, 11) + # Warning raised batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])} _check_warning_raised(batch, 1) @@ -55,6 +58,20 @@ def _check_warning_raised(data, expected): batch = [{"test": torch.zeros(10, 10), "test_1": torch.zeros(11, 10)}] _check_warning_raised(batch, 10) + # Error raised + batch = "test string" + _check_error_raised(batch) + + data = {"test": ["some text"] * 7} + _check_error_raised(data) + + class CustomBatch: + def __init__(self): + self.x = torch.randn(7, 2) + + data = CustomBatch() + _check_error_raised(data) + def test_has_iterable_dataset(): assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1))) From 338f3cf63686935355c749920b2f298f3d18a26f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 22 Nov 2021 18:33:47 +0100 Subject: [PATCH 18/59] Use `Set` operations in `Environment.detect` (#10673) --- .../plugins/environments/kubeflow_environment.py | 7 ++++--- pytorch_lightning/plugins/environments/lsf_environment.py | 4 ++-- .../plugins/environments/torchelastic_environment.py | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/environments/kubeflow_environment.py b/pytorch_lightning/plugins/environments/kubeflow_environment.py index 4682d18270792..03dfdde9d78a0 100644 --- a/pytorch_lightning/plugins/environments/kubeflow_environment.py +++ b/pytorch_lightning/plugins/environments/kubeflow_environment.py @@ -53,10 +53,11 @@ def main_port(self) -> int: @staticmethod def detect() -> bool: """Returns ``True`` if the current process was launched using Kubeflow PyTorchJob.""" - required_env_vars = ("KUBERNETES_PORT", "MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK") + required_env_vars = {"KUBERNETES_PORT", "MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"} # torchelastic sets these. 
Make sure we're not in torchelastic - excluded_env_vars = ("GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE") - return all(v in os.environ for v in required_env_vars) and not any(v in os.environ for v in excluded_env_vars) + excluded_env_vars = {"GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"} + env_vars = os.environ.keys() + return required_env_vars.issubset(env_vars) and excluded_env_vars.isdisjoint(env_vars) def world_size(self) -> int: return int(os.environ["WORLD_SIZE"]) diff --git a/pytorch_lightning/plugins/environments/lsf_environment.py b/pytorch_lightning/plugins/environments/lsf_environment.py index 3945082e8b784..c25d068ae01bb 100644 --- a/pytorch_lightning/plugins/environments/lsf_environment.py +++ b/pytorch_lightning/plugins/environments/lsf_environment.py @@ -71,8 +71,8 @@ def main_port(self) -> int: @staticmethod def detect() -> bool: """Returns ``True`` if the current process was launched using the jsrun command.""" - required_env_vars = ("LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE") - return all(v in os.environ for v in required_env_vars) + required_env_vars = {"LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"} + return required_env_vars.issubset(os.environ.keys()) def world_size(self): """The world size is read from the environment variable `JSM_NAMESPACE_SIZE`.""" diff --git a/pytorch_lightning/plugins/environments/torchelastic_environment.py b/pytorch_lightning/plugins/environments/torchelastic_environment.py index bb228607e3183..3631f32daa8d4 100644 --- a/pytorch_lightning/plugins/environments/torchelastic_environment.py +++ b/pytorch_lightning/plugins/environments/torchelastic_environment.py @@ -61,8 +61,8 @@ def main_port(self) -> int: @staticmethod def detect() -> bool: """Returns ``True`` if the current process was launched using the torchelastic command.""" - required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE") - return all(v in os.environ for v in required_env_vars) + required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"} + return required_env_vars.issubset(os.environ.keys()) def world_size(self) -> Optional[int]: world_size = os.environ.get("WORLD_SIZE") From 6fc7c54c3af6dc3a446253df379e459d18bacb85 Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Mon, 22 Nov 2021 18:41:08 +0100 Subject: [PATCH 19/59] refactor slurm_job_id (#10622) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: thomas chaton Co-authored-by: Adrian Wälchli Co-authored-by: ananthsub --- CHANGELOG.md | 2 +- .../plugins/environments/slurm_environment.py | 16 ++++++++++++++++ .../logger_connector/logger_connector.py | 3 ++- pytorch_lightning/trainer/trainer.py | 15 +++------------ tests/deprecated_api/test_remove_1-7.py | 6 ++++++ 5 files changed, 28 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fc3178b4426e..46a58ccfb4877 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) -- +- Deprecated the property `Trainer.slurm_job_id` in favor of the new `SLURMEnvironment.job_id()` method ([#10622](https://github.com/PyTorchLightning/pytorch-lightning/pull/10622)) - diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py index 53fa4c2a83aa7..ad657e1e19564 100644 --- a/pytorch_lightning/plugins/environments/slurm_environment.py +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -15,6 +15,7 @@ import logging import os import re +from typing import Optional from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment @@ -37,6 +38,21 @@ def __init__(self, auto_requeue: bool = True) -> None: def creates_processes_externally(self) -> bool: return True + @staticmethod + def job_id() -> Optional[int]: + job_id = os.environ.get("SLURM_JOB_ID") + if job_id: + try: + job_id = int(job_id) + except ValueError: + job_id = None + + # in interactive mode, don't make logs use the same job id + in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash" + if in_slurm_interactive_mode: + job_id = None + return job_id + @property def main_address(self) -> str: # figure out the root node addr diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index d970d98c602bc..b98f13138b36f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -18,6 +18,7 @@ import pytorch_lightning as pl from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger +from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType, memory @@ -81,7 +82,7 @@ def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[Lig # default logger self.trainer.logger = ( TensorBoardLogger( - save_dir=self.trainer.default_root_dir, version=self.trainer.slurm_job_id, name="lightning_logs" + save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id(), name="lightning_logs" ) if logger else None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cb446632cc9e3..1ccdb9ecaeca8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -39,6 +39,7 @@ from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop from pytorch_lightning.plugins import DDPSpawnPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.profiler import ( AdvancedProfiler, BaseProfiler, @@ -1725,18 +1726,8 @@ def is_global_zero(self) -> bool: @property def slurm_job_id(self) -> Optional[int]: - job_id = os.environ.get("SLURM_JOB_ID") - if job_id: - try: - job_id = int(job_id) - except ValueError: - job_id = None - - # in interactive mode, don't make logs use the same job id - 
in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash" - if in_slurm_interactive_mode: - job_id = None - return job_id + rank_zero_deprecation("Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0.") + return SLURMEnvironment.job_id() @property def lightning_optimizers(self) -> List[LightningOptimizer]: diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index c8016961963ba..9c0c12f981f4b 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -378,6 +378,12 @@ def test_v1_7_0_trainer_log_gpu_memory(tmpdir): _ = Trainer(log_gpu_memory="min_max") +def test_v1_7_0_deprecated_slurm_job_id(): + trainer = Trainer() + with pytest.deprecated_call(match="Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0."): + trainer.slurm_job_id + + @RunIf(min_gpus=1) def test_v1_7_0_deprecate_gpu_stats_monitor(tmpdir): with pytest.deprecated_call(match="The `GPUStatsMonitor` callback was deprecated in v1.5"): From 15305c459c1fd80619775d4f6e671a1d30202b7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 22 Nov 2021 18:45:34 +0100 Subject: [PATCH 20/59] Update DDPShardedPlugin precision handling after moving PrecisionPlugin (#10658) --- pytorch_lightning/lite/lite.py | 11 +---------- pytorch_lightning/plugins/training_type/sharded.py | 7 +------ 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index f5fdd0221cbe3..4997d7db779e7 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -31,14 +31,7 @@ _LiteOptimizer, _replace_dataloader_init_method, ) -from pytorch_lightning.plugins import ( - DDPShardedPlugin, - DDPSpawnPlugin, - DeepSpeedPlugin, - PLUGIN_INPUT, - TPUSpawnPlugin, - TrainingTypePlugin, -) +from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.utilities import _StrategyType, DeviceType, move_data_to_device from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors @@ -411,8 +404,6 @@ def _set_plugin_specific_precision_variables(self) -> None: # todo: these are hacks as plugins rely on access to the precision plugin if isinstance(self._strategy, DeepSpeedPlugin): self._set_deepspeed_precision_variables() - if isinstance(self._strategy, DDPShardedPlugin): - self._strategy._precision = self._accelerator_connector.precision def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: if isinstance(self._strategy, TPUSpawnPlugin): diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index eb4cb48534708..c9627324eb237 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -39,10 +39,6 @@ class DDPShardedPlugin(DDPPlugin): distributed_backend = _StrategyType.DDP_SHARDED _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._precision = None - def configure_ddp(self) -> None: trainer = self.lightning_module.trainer if "reduce_buffer_size" not in self._ddp_kwargs: @@ -75,8 +71,7 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin optim_class = 
type(optimizer) zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE: - precision = self._precision or self.precision_plugin.precision - is_fp16 = precision in ("mixed", 16) + is_fp16 = self.precision_plugin.precision in ("mixed", 16) # For multi-node training, compressing the model shards in fp16 before broadcasting # improves performance. When using PyTorch AMP, it will not degrade # the model performance. From cd7b4342f68ca98c0c9f6637e3c8eff09f0ad8a0 Mon Sep 17 00:00:00 2001 From: shabie <30535146+shabie@users.noreply.github.com> Date: Mon, 22 Nov 2021 18:54:19 +0100 Subject: [PATCH 21/59] remove import of datasets separately since unused (#10668) --- docs/source/starter/introduction_guide.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index 6247a899ed869..9c6bf5a69e205 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -139,7 +139,7 @@ Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIS from torch.utils.data import DataLoader, random_split from torchvision.datasets import MNIST import os - from torchvision import datasets, transforms + from torchvision import transforms # transforms # prepare transforms standard to MNIST From 6acfef680ffd27ffb78003a49676b17aa5042bfa Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 22 Nov 2021 18:48:32 +0000 Subject: [PATCH 22/59] Fault Tolerant Manual: Add is_obj_stateful utility (#10646) --- CHANGELOG.md | 1 + pytorch_lightning/utilities/auto_restart.py | 12 +++++++++++ tests/utilities/test_auto_restart.py | 24 +++++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 46a58ccfb4877..a454739c4e415 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault Tolerant Manual + * Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 228e16e4e9c8c..23583852f4f39 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -22,6 +22,7 @@ import torch from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset +from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities.enums import AutoRestartBatchKeys @@ -570,3 +571,14 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A else: raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") + + +@runtime_checkable +class _SupportsStateDict(Protocol): + """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" + + def state_dict(self) -> Dict[str, Any]: + ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... 
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b9eb97cb42ae8..5152874b39469 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -40,6 +40,7 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _SupportsStateDict, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -1195,6 +1196,29 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict +def test_supports_state_dict_protocol(): + class StatefulClass: + def state_dict(self): + pass + + def load_state_dict(self, state_dict): + pass + + assert isinstance(StatefulClass(), _SupportsStateDict) + + class NotStatefulClass: + def state_dict(self): + pass + + assert not isinstance(NotStatefulClass(), _SupportsStateDict) + + class NotStateful2Class: + def load_state_dict(self, state_dict): + pass + + assert not isinstance(NotStateful2Class(), _SupportsStateDict) + + def test_fault_tolerant_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode() From 823bfa6f8a16a8ff2f93c8d5893549424e5b471b Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 23 Nov 2021 01:02:04 +0530 Subject: [PATCH 23/59] Update `LightningModule` docs (#10637) --- docs/source/common/lightning_module.rst | 382 ++++++++++++++++-------- pytorch_lightning/core/hooks.py | 31 +- tests/models/test_hooks.py | 2 +- tests/trainer/test_dataloaders.py | 6 +- 4 files changed, 276 insertions(+), 145 deletions(-) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 166b0b2384461..ca6950fe36e0b 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -5,13 +5,14 @@ LightningModule =============== -A :class:`~LightningModule` organizes your PyTorch code into 5 sections +A :class:`~LightningModule` organizes your PyTorch code into 6 sections: - Computations (init). - Train loop (training_step) - Validation loop (validation_step) - Test loop (test_step) -- Optimizers (configure_optimizers) +- Prediction loop (predict_step) +- Optimizers and LR Schedulers (configure_optimizers) | @@ -23,10 +24,10 @@ A :class:`~LightningModule` organizes your PyTorch code into 5 sections Notice a few things. -1. It's the SAME code. +1. It is the SAME code. 2. The PyTorch code IS NOT abstracted - just organized. 3. All the other code that's not in the :class:`~LightningModule` - has been automated for you by the trainer. + has been automated for you by the Trainer. | @@ -36,13 +37,13 @@ Notice a few things. trainer = Trainer() trainer.fit(net) -4. There are no .cuda() or .to() calls... Lightning does these for you. +4. There are no ``.cuda()`` or ``.to(device)`` calls required. Lightning does these for you. | .. code-block:: python - # don't do in lightning + # don't do in Lightning x = torch.Tensor(2, 3) x = x.cuda() x = x.to(device) @@ -54,7 +55,7 @@ Notice a few things. new_x = torch.Tensor(2, 3) new_x = new_x.type_as(x) -5. Lightning by default handles the distributed sampler for you. +5. When running under a distributed strategy, Lightning handles the distributed sampler for you by default. | @@ -116,10 +117,10 @@ Which you can train by doing: .. 
code-block:: python train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())) - trainer = pl.Trainer() + trainer = pl.Trainer(max_epochs=1) model = LitModel() - trainer.fit(model, train_loader) + trainer.fit(model, train_dataloaders=train_loader) The LightningModule has many convenience methods, but the core ones you need to know about are: @@ -134,11 +135,13 @@ The LightningModule has many convenience methods, but the core ones you need to * - forward - Use for inference only (separate from training_step) * - training_step - - the full training loop + - the complete training loop * - validation_step - - the full validation loop + - the complete validation loop * - test_step - - the full test loop + - the complete test loop + * - predict_step + - the complete prediction loop * - configure_optimizers - define optimizers and LR schedulers @@ -149,7 +152,7 @@ Training Training loop ^^^^^^^^^^^^^ -To add a training loop use the `training_step` method +To activate the training loop, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method. .. code-block:: python @@ -190,7 +193,7 @@ Under the hood, Lightning does the following (pseudocode): Training epoch-level metrics ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you want to calculate epoch-level metrics and log them, use the `.log` method +If you want to calculate epoch-level metrics and log them, use :meth:`~pytorch_lightning.core.lightning.LightningModule.log`. .. code-block:: python @@ -204,8 +207,8 @@ If you want to calculate epoch-level metrics and log them, use the `.log` method self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) return loss -The `.log` object automatically reduces the requested metrics across the full epoch. -Here's the pseudocode of what it does under the hood: +The :meth:`~pytorch_lightning.core.lightning.LightningModule.log` object automatically reduces the +requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood: .. code-block:: python @@ -228,7 +231,8 @@ Here's the pseudocode of what it does under the hood: Train epoch-level operations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you need to do something with all the outputs of each `training_step`, override `training_epoch_end` yourself. +If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`. +override the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_epoch_end` method. .. code-block:: python @@ -241,8 +245,8 @@ If you need to do something with all the outputs of each `training_step`, overri def training_epoch_end(self, training_step_outputs): - for pred in training_step_outputs: - ... + all_preds = torch.stack(training_step_outputs) + ... The matching pseudocode is: @@ -267,10 +271,11 @@ The matching pseudocode is: Training with DataParallel ~~~~~~~~~~~~~~~~~~~~~~~~~~ -When training using an `accelerator` that splits data from each batch across GPUs, sometimes you might -need to aggregate them on the main GPU for processing (dp, or ddp2). +When training using a ``strategy`` that splits data from each batch across GPUs, sometimes you might +need to aggregate them on the main GPU for processing (DP, or DDP2). 
-In this case, implement the `training_step_end` method +In this case, implement the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end` +method which will have outputs from all the devices and you can accumulate to get the effective results. .. code-block:: python @@ -299,7 +304,7 @@ In this case, implement the `training_step_end` method for out in training_step_outputs: ... -The full pseudocode that lighting does under the hood is: +Here is the Lightning training pseudo-code for DP: .. code-block:: python @@ -324,7 +329,7 @@ The full pseudocode that lighting does under the hood is: Validation loop ^^^^^^^^^^^^^^^ -To add a validation loop, override the `validation_step` method of the :class:`~LightningModule`: +To activate the validation loop while training, override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` method. .. code-block:: python @@ -335,7 +340,7 @@ To add a validation loop, override the `validation_step` method of the :class:`~ loss = F.cross_entropy(y_hat, y) self.log("val_loss", loss) -Under the hood, Lightning does the following: +Under the hood, Lightning does the following (pseudocode): .. code-block:: python @@ -359,9 +364,19 @@ Under the hood, Lightning does the following: torch.set_grad_enabled(True) model.train() +You can also run just the validation loop on your validation dataloaders by overriding :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` +and calling :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. + +.. code-block:: python + + model = Model() + trainer = Trainer() + trainer.validate(model) + Validation epoch-level metrics ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you need to do something with all the outputs of each `validation_step`, override `validation_epoch_end`. +If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`, +override the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` method. .. code-block:: python @@ -374,15 +389,16 @@ If you need to do something with all the outputs of each `validation_step`, over def validation_epoch_end(self, validation_step_outputs): - for pred in validation_step_outputs: - ... + all_preds = torch.stack(validation_step_outputs) + ... Validating with DataParallel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -When training using an `accelerator` that splits data from each batch across GPUs, sometimes you might -need to aggregate them on the main GPU for processing (dp, or ddp2). +When training using a ``strategy`` that splits data from each batch across GPUs, sometimes you might +need to aggregate them on the main GPU for processing (DP, or DDP2). -In this case, implement the `validation_step_end` method +In this case, implement the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step_end` +method which will have outputs from all the devices and you can accumulate to get the effective results. .. code-block:: python @@ -411,7 +427,7 @@ In this case, implement the `validation_step_end` method for out in validation_step_outputs: ... -The full pseudocode that lighting does under the hood is: +Here is the Lightning validation pseudo-code for DP: .. code-block:: python @@ -436,21 +452,21 @@ The full pseudocode that lighting does under the hood is: Test loop ^^^^^^^^^ -The process for adding a test loop is the same as the process for adding a validation loop. Please refer to -the section above for details. 
+The process for enabling a test loop is the same as the process for enabling a validation loop. Please refer to +the section above for details. For this you need to override the :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` method. -The only difference is that the test loop is only called when `.test()` is used: +The only difference is that the test loop is only called when :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` is used. .. code-block:: python model = Model() trainer = Trainer() - trainer.fit() + trainer.fit(model) # automatically loads the best weights for you trainer.test(model) -There are two ways to call `test()`: +There are two ways to call ``test()``: .. code-block:: python @@ -458,7 +474,7 @@ There are two ways to call `test()`: trainer = Trainer() trainer.fit(model) - # automatically auto-loads the best weights + # automatically auto-loads the best weights from the previous run trainer.test(dataloaders=test_dataloader) # or call with pretrained model @@ -468,8 +484,8 @@ There are two ways to call `test()`: ---------- -Inference ---------- +Inference (Prediction Loop) +^^^^^^^^^^^^^^^^^^^^^^^^^^^ For research, LightningModules are best structured as systems. .. code-block:: python @@ -531,7 +547,7 @@ This simple model generates examples that look like this (the encoders and decod .. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/ae_docs.png :width: 300 -The methods above are part of the lightning interface: +The methods above are part of the LightningModule interface: - training_step - validation_step @@ -539,7 +555,7 @@ The methods above are part of the lightning interface: - predict_step - configure_optimizers -Note that in this case, the train loop and val loop are exactly the same. We can of course reuse this code. +Note that in this case, the train loop and val loop are exactly the same. We can, of course, reuse this code. .. code-block:: python @@ -574,11 +590,11 @@ Note that in this case, the train loop and val loop are exactly the same. We can def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.0002) -We create a new method called `shared_step` that all loops can use. This method name is arbitrary and NOT reserved. +We create a new method called ``shared_step`` that all loops can use. This method name is arbitrary and NOT reserved. Inference in research ^^^^^^^^^^^^^^^^^^^^^ -In the case where we want to perform inference with the system we can add a `forward` method to the LightningModule. +If you want to perform inference with the system, you can add a ``forward`` method to the LightningModule. .. note:: When using forward, you are responsible to call :func:`~torch.nn.Module.eval` and use the :func:`~torch.no_grad` context manager. @@ -633,8 +649,7 @@ For cases like production, you might want to iterate different models inside a L .. 
code-block:: python - import pytorch_lightning as pl - from pytorch_lightning.metrics import functional as FM + from torchmetrics.functional import accuracy class ClassificationTask(pl.LightningModule): @@ -664,12 +679,13 @@ For cases like production, you might want to iterate different models inside a L x, y = batch y_hat = self.model(x) loss = F.cross_entropy(y_hat, y) - acc = FM.accuracy(y_hat, y) + acc = accuracy(y_hat, y) return loss, acc def predict_step(self, batch, batch_idx, dataloader_idx=0): x, y = batch y_hat = self.model(x) + return y_hat def configure_optimizers(self): return torch.optim.Adam(self.model.parameters(), lr=0.02) @@ -682,7 +698,7 @@ Then pass in any arbitrary model to be fit with this task task = ClassificationTask(model) trainer = Trainer(gpus=2) - trainer.fit(task, train_dataloader, val_dataloader) + trainer.fit(task, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) Tasks can be arbitrarily complex such as implementing GAN training, self-supervised or even RL. @@ -697,21 +713,22 @@ Tasks can be arbitrarily complex such as implementing GAN training, self-supervi ... When used like this, the model can be separated from the Task and thus used in production without needing to keep it in -a `LightningModule`. +a ``LightningModule``. -- You can export to onnx. -- Or trace using Jit. -- or run in the python runtime. +- You can export to onnx using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_onnx`. +- Or trace using Jit using :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript`. +- Or run in the Python runtime. .. code-block:: python - task = ClassificationTask(model) + task = ClassificationTask(model) - trainer = Trainer(gpus=2) - trainer.fit(task, train_dataloader, val_dataloader) + trainer = Trainer(gpus=2) + trainer.fit(task, train_dataloader, val_dataloader) - # use model after training or load weights and drop into the production system - model.eval() + # use model after training or load weights and drop into the production system + model.eval() + with torch.no_grad(): y_hat = model(x) ----------- @@ -722,6 +739,12 @@ LightningModule API Methods ^^^^^^^ +all_gather +~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.all_gather + :noindex: + configure_callbacks ~~~~~~~~~~~~~~~~~~~ @@ -758,12 +781,24 @@ log_dict .. automethod:: pytorch_lightning.core.lightning.LightningModule.log_dict :noindex: +lr_schedulers +~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.lr_schedulers + :noindex: + manual_backward ~~~~~~~~~~~~~~~ .. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward :noindex: +optimizers +~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.optimizers + :noindex: + print ~~~~~ @@ -782,6 +817,12 @@ save_hyperparameters .. automethod:: pytorch_lightning.core.lightning.LightningModule.save_hyperparameters :noindex: +toggle_optimizer +~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.toggle_optimizer + :noindex: + test_step ~~~~~~~~~ @@ -835,6 +876,12 @@ unfreeze .. automethod:: pytorch_lightning.core.lightning.LightningModule.unfreeze :noindex: +untoggle_optimizer +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.untoggle_optimizer + :noindex: + validation_step ~~~~~~~~~~~~~~~ @@ -875,7 +922,7 @@ The current epoch device ~~~~~~ -The device the module is on. Use it to keep your code device agnostic +The device the module is on. 
Use it to keep your code device agnostic. .. code-block:: python @@ -886,11 +933,16 @@ The device the module is on. Use it to keep your code device agnostic global_rank ~~~~~~~~~~~ -The global_rank of this LightningModule. Lightning saves logs, weights etc only from global_rank = 0. You -normally do not need to use this property +The ``global_rank`` is the index of the current process across all nodes and devices. +Lightning will perform some operations such as logging, weight checkpointing only when ``global_rank=0``. You +usually do not need to use this property, but it is useful to know how to access it if needed. -Global rank refers to the index of that GPU across ALL GPUs. For example, if using 10 machines, each with 4 GPUs, -the 4th GPU on the 10th machine has global_rank = 39 +.. code-block:: python + + def training_step(self): + if self.global_rank == 0: + # do something only once across all the nodes + self.log("global_step", self.trainer.global_step) ------------- @@ -907,8 +959,8 @@ The current step (does not reset each epoch) hparams ~~~~~~~ -The arguments saved by calling ``save_hyperparameters`` passed through ``__init__()`` - could be accessed by the ``hparams`` attribute. +The arguments passed through ``LightningModule.__init__()`` and saved by calling +:meth:`~pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin.save_hyperparameters` could be accessed by the ``hparams`` attribute. .. code-block:: python @@ -938,12 +990,16 @@ The current logger being used (tensorboard or other supported logger) local_rank ~~~~~~~~~~~ -The local_rank of this LightningModule. Lightning saves logs, weights etc only from global_rank = 0. You -normally do not need to use this property +The ``global_rank`` is the index of the current process across all the devices for the current node. +You usually do not need to use this property, but it is useful to know how to access it if needed. +For example, if using 10 machines (or nodes), the GPU at index 0 on each machine has local_rank = 0. -Local rank refers to the rank on that machine. For example, if using 10 machines, the GPU at index 0 on each machine -has local_rank = 0. +.. code-block:: python + def training_step(self): + if self.global_rank == 0: + # do something only once across each node + self.log("global_step", self.trainer.global_step) ----------- @@ -973,7 +1029,7 @@ Pointer to the trainer use_amp ~~~~~~~ -True if using Automatic Mixed Precision (AMP) +``True`` if using Automatic Mixed Precision (AMP) -------------- @@ -1026,7 +1082,7 @@ Manual optimization is most useful for research topics like reinforcement learni example_input_array ~~~~~~~~~~~~~~~~~~~ -Set and access example_input_array which is basically a single batch. +Set and access example_input_array, which basically represents a single batch. .. code-block:: python @@ -1062,15 +1118,15 @@ Get the model file size (in megabytes) using ``self.model_size`` inside Lightnin truncated_bptt_steps ^^^^^^^^^^^^^^^^^^^^ -Truncated back prop breaks performs backprop every k steps of +Truncated Backpropagation Through Time (TBPTT) performs perform backpropogation every k steps of a much longer sequence. This is made possible by passing training batches -splitted along the time-dimensions into splits of size k to the +split along the time-dimensions into splits of size k to the ``training_step``. In order to keep the same forward propagation behavior, all hidden states should be kept in-between each time-dimension split. 
If this is enabled, your batches will automatically get truncated -and the trainer will apply Truncated Backprop to it. +and the Trainer will apply Truncated Backprop to it. (`Williams et al. "An efficient gradient-based algorithm for on-line training of recurrent network trajectories." @@ -1114,7 +1170,7 @@ recurrent network trajectories." Lightning takes care of splitting your batch along the time-dimension. It is assumed to be the second dimension of your batches. Therefore, in the -example above we have set ``batch_first=True``. +example above, we have set ``batch_first=True``. .. code-block:: python @@ -1123,7 +1179,7 @@ example above we have set ``batch_first=True``. sub_batch = batch[0, 0:t, ...] To modify how the batch is split, -override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`: +override the :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch` method: .. testcode:: python @@ -1246,211 +1302,271 @@ backward on_before_backward ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_backward +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_backward :noindex: on_after_backward ~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward :noindex: on_before_zero_grad ~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad :noindex: on_fit_start ~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_fit_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_fit_start :noindex: on_fit_end ~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_fit_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_fit_end :noindex: on_load_checkpoint ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint :noindex: on_save_checkpoint ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint + :noindex: + +load_from_checkpoint +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.load_from_checkpoint + :noindex: + +on_hpc_save +~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_hpc_save + :noindex: + +on_hpc_load +~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_hpc_load :noindex: on_train_start ~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_start :noindex: on_train_end ~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_end :noindex: on_validation_start ~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_start :noindex: on_validation_end ~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end +.. 
automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_end :noindex: on_pretrain_routine_start ~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_start :noindex: on_pretrain_routine_end ~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_pretrain_routine_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_end :noindex: on_test_batch_start ~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_batch_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_batch_start :noindex: on_test_batch_end ~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_batch_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_batch_end :noindex: on_test_epoch_start ~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_epoch_start :noindex: on_test_epoch_end ~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_epoch_end :noindex: on_test_start ~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_start :noindex: on_test_end ~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_end + :noindex: + +on_predict_batch_start +~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_batch_start + :noindex: + +on_predict_batch_end +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_batch_end + :noindex: + +on_predict_epoch_start +~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_epoch_start + :noindex: + +on_predict_epoch_end +~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_epoch_end + :noindex: + +on_predict_start +~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_start + :noindex: + +on_predict_end +~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_end :noindex: on_train_batch_start ~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_batch_start :noindex: on_train_batch_end ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_batch_end :noindex: on_epoch_start ~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_epoch_start :noindex: on_epoch_end ~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_epoch_end :noindex: on_train_epoch_start ~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_start +.. 
automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_epoch_start :noindex: on_train_epoch_end ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_epoch_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_epoch_end :noindex: on_validation_batch_start ~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_batch_start :noindex: on_validation_batch_end ~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_batch_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_batch_end :noindex: on_validation_epoch_start ~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_start +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_epoch_start :noindex: on_validation_epoch_end ~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_epoch_end :noindex: on_post_move_to_device ~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_post_move_to_device + :noindex: + +configure_sharded_model +~~~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_sharded_model :noindex: on_validation_model_eval ~~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_model_eval :noindex: on_validation_model_train ~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_model_train :noindex: on_test_model_eval ~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_model_eval :noindex: on_test_model_train ~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_model_train :noindex: on_before_optimizer_step ~~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_optimizer_step :noindex: configure_gradient_clipping @@ -1480,7 +1596,7 @@ prepare_data setup ~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup +.. automethod:: pytorch_lightning.core.lightning.LightningModule.setup :noindex: tbptt_split_batch @@ -1492,43 +1608,73 @@ tbptt_split_batch teardown ~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown +.. automethod:: pytorch_lightning.core.lightning.LightningModule.teardown :noindex: train_dataloader ~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader +.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader :noindex: val_dataloader ~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader +.. 
automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader :noindex: test_dataloader ~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader +.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader + :noindex: + +predict_dataloader +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.predict_dataloader + :noindex: + +on_train_dataloader +~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_dataloader + :noindex: + +on_val_dataloader +~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_val_dataloader + :noindex: + +on_test_dataloader +~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_test_dataloader + :noindex: + +on_predict_dataloader +~~~~~~~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_predict_dataloader :noindex: transfer_batch_to_device ~~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device +.. automethod:: pytorch_lightning.core.lightning.LightningModule.transfer_batch_to_device :noindex: on_before_batch_transfer ~~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_before_batch_transfer +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_batch_transfer :noindex: on_after_batch_transfer ~~~~~~~~~~~~~~~~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.DataHooks.on_after_batch_transfer +.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_batch_transfer :noindex: add_to_queue diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 376a6919ca43f..5263c16952fec 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -423,11 +423,9 @@ def train_dataloader(self) -> TRAIN_DATALOADERS: .. warning:: do not assign state in prepare_data - - :meth:`~pytorch_lightning.trainer.Trainer.fit` - - ... + - :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` - :meth:`prepare_data` - :meth:`setup` - - :meth:`train_dataloader` Note: Lightning adds the correct sampler for distributed and arbitrary hardware. @@ -480,10 +478,6 @@ def test_dataloader(self) -> EVAL_DATALOADERS: r""" Implement one or multiple PyTorch DataLoaders for testing. - The dataloader you return will not be reloaded unless you set - :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to - a postive integer. - For data processing use the following pattern: - download in :meth:`prepare_data` @@ -494,13 +488,9 @@ def test_dataloader(self) -> EVAL_DATALOADERS: .. warning:: do not assign state in prepare_data - - :meth:`~pytorch_lightning.trainer.Trainer.fit` - - ... + - :meth:`~pytorch_lightning.trainer.trainer.Trainer.test` - :meth:`prepare_data` - :meth:`setup` - - :meth:`train_dataloader` - - :meth:`val_dataloader` - - :meth:`test_dataloader` Note: Lightning adds the correct sampler for distributed and arbitrary hardware. @@ -548,12 +538,10 @@ def val_dataloader(self) -> EVAL_DATALOADERS: It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. - - :meth:`~pytorch_lightning.trainer.Trainer.fit` - - ... 
+ - :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` + - :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate` - :meth:`prepare_data` - - :meth:`train_dataloader` - - :meth:`val_dataloader` - - :meth:`test_dataloader` + - :meth:`setup` Note: Lightning adds the correct sampler for distributed and arbitrary hardware @@ -597,12 +585,9 @@ def predict_dataloader(self) -> EVAL_DATALOADERS: It's recommended that all data downloads and preparation happen in :meth:`prepare_data`. - - :meth:`~pytorch_lightning.trainer.Trainer.fit` - - ... + - :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict` - :meth:`prepare_data` - - :meth:`train_dataloader` - - :meth:`val_dataloader` - - :meth:`test_dataloader` + - :meth:`setup` Note: Lightning adds the correct sampler for distributed and arbitrary hardware @@ -612,7 +597,7 @@ def predict_dataloader(self) -> EVAL_DATALOADERS: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples. Note: - In the case where you return multiple prediction dataloaders, the :meth:`predict` + In the case where you return multiple prediction dataloaders, the :meth:`predict_step` will have an argument ``dataloader_idx`` which matches the order here. """ raise NotImplementedError("`predict_dataloader` must be implemented to be used with the Lightning Trainer") diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 19c4e71d54fc4..35b50acfcef4f 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -866,7 +866,7 @@ def call(hook, fn, *args, **kwargs): limit_predict_batches=batches, enable_progress_bar=False, enable_model_summary=False, - reload_dataloaders_every_n_epochs=True, + reload_dataloaders_every_n_epochs=1, ) called = [] diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 59c88e972d6a6..da7e0704cd4e2 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1127,7 +1127,7 @@ def test_dataloaders_load_only_once_val_interval(tmpdir): limit_train_batches=10, limit_val_batches=10, val_check_interval=0.3, - reload_dataloaders_every_n_epochs=True, + reload_dataloaders_every_n_epochs=1, max_epochs=3, ) @@ -1245,7 +1245,7 @@ def validation_step(self, batch, batch_idx): limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, - reload_dataloaders_every_n_epochs=True, + reload_dataloaders_every_n_epochs=1, max_epochs=3, callbacks=[checkpoint_callback], ) @@ -1272,7 +1272,7 @@ def validation_step(self, batch, batch_idx): # the val dataloader on the first epoch because this only tracks the training epoch # meaning multiple passes through the validation data within a single training epoch # would not have the dataloader reloaded. 
- # This breaks the assumption behind reload_dataloaders_every_n_epochs=True + # This breaks the assumption behind reload_dataloaders_every_n_epochs=1 call.val_dataloader(), call.train_dataloader(), call.val_dataloader(), From 2036dfb5df2a48056d1c043ee311aed793187070 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 22 Nov 2021 19:52:04 +0000 Subject: [PATCH 24/59] Fault Tolerant Manual: Add _rotate_worker_indices utility (#10647) --- CHANGELOG.md | 1 + pytorch_lightning/utilities/auto_restart.py | 23 ++++++++++++++------- tests/utilities/test_auto_restart.py | 14 +++++++++++++ 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a454739c4e415..adb1b070dc386 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault Tolerant Manual * Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) + * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647)) - diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 23583852f4f39..4cb1793643c1d 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -247,14 +247,7 @@ def __len__(self) -> int: def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None: # as workers aren't available, the ``state_dict``` is cached until workers are made available. state_dict = deepcopy(state_dict) - - if num_workers > 0: - # remap states to worker ids starting at 0 - next_worker_id = latest_worker_id + 1 - old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)] - state_dict = { - new_id: state_dict[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state_dict - } + state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers) self._cached_state_dict = state_dict def state_dict(self) -> Dict[int, Dict[str, Any]]: @@ -573,6 +566,20 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. 
Please, open an issue on PyTorch Lightning Github.")
+def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
+    """This function is used to rotate the worker indices based on the `latest_worker_id` the training failed
+    on."""
+    if num_workers == 0:
+        return state
+    if latest_worker_id > num_workers - 1:
+        raise MisconfigurationException("The `latest_worker_id` should be within [0, num_workers - 1].")
+    if len(state) != num_workers:
+        raise MisconfigurationException("The `state` should contain `num_workers - 1` values.")
+    next_worker_id = latest_worker_id + 1
+    old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)]
+    return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}
+
+
 @runtime_checkable
 class _SupportsStateDict(Protocol):
     """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index 5152874b39469..d9063f90db377 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -40,6 +40,7 @@
     _add_capture_metadata_collate,
     _dataloader_load_state_dict,
     _dataloader_to_state_dict,
+    _rotate_worker_indices,
     _SupportsStateDict,
     CaptureIterableDataset,
     CaptureMapDataset,
@@ -1196,6 +1197,19 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
     assert "dataloader_state_dict" in state_dict
+def test_rotate_worker_indices():
+    """This test ensures `worker_id` are rotated properly depending on which one was the latest."""
+    state_dict = {0: 0, 1: 1}
+    assert _rotate_worker_indices(state_dict, 0, 2) == {0: 1, 1: 0}
+    assert _rotate_worker_indices(state_dict, 1, 2) == {0: 0, 1: 1}
+
+    with pytest.raises(MisconfigurationException, match="The `latest_worker_id` should be within"):
+        _rotate_worker_indices(state_dict, 2, 2)
+
+    with pytest.raises(MisconfigurationException, match="The `state` should contain"):
+        _rotate_worker_indices(state_dict, 2, 3)
+
+
 def test_supports_state_dict_protocol():
     class StatefulClass:
         def state_dict(self):
From 48cf1adfd3ad9c7e659083a4afc334dafb331f28 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Tue, 23 Nov 2021 11:46:31 +0530
Subject: [PATCH 25/59] Move Colab setup to ProgressBar (#10542)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 pytorch_lightning/callbacks/progress/tqdm_progress.py | 11 ++++++++++-
 .../trainer/connectors/callback_connector.py | 3 +--
 tests/callbacks/test_tqdm_progress_bar.py | 2 +-
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py
index 672d9d893ad61..9de6770c6c23d 100644
--- a/pytorch_lightning/callbacks/progress/tqdm_progress.py
+++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -26,6 +26,7 @@
 from tqdm import tqdm as _tqdm
 from pytorch_lightning.callbacks.progress.base import ProgressBarBase
+from pytorch_lightning.utilities.distributed import rank_zero_debug
 _PAD_SIZE = 5
@@ -100,7 +101,7 @@ class TQDMProgressBar(ProgressBarBase):
     def __init__(self, refresh_rate: int = 1, process_position: int = 0):
         super().__init__()
-        self._refresh_rate = refresh_rate
+
self._refresh_rate = self._resolve_refresh_rate(refresh_rate) self._process_position = process_position self._enabled = True self.main_progress_bar = None @@ -324,6 +325,14 @@ def _update_bar(self, bar: Optional[Tqdm]) -> None: if delta > 0: bar.update(delta) + @staticmethod + def _resolve_refresh_rate(refresh_rate: int) -> int: + if os.getenv("COLAB_GPU") and refresh_rate == 1: + # smaller refresh rate on colab causes crashes, choose a higher value + rank_zero_debug("Using a higher refresh rate on Colab. Setting it to `20`") + refresh_rate = 20 + return refresh_rate + def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: """The tqdm doesn't support inf/nan values. diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 55730c0d79c72..a662d42b2b1af 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -238,8 +238,7 @@ def _configure_progress_bar( if refresh_rate == 0 or not enable_progress_bar: return if refresh_rate is None: - # smaller refresh rate on colab causes crashes, choose a higher value - refresh_rate = 20 if os.getenv("COLAB_GPU") else 1 + refresh_rate = 1 progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position) self.trainer.callbacks.append(progress_bar_callback) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 89928c17a6de4..1ff1a602fe3b6 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -274,7 +274,7 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir): assert trainer.progress_bar_callback.refresh_rate == 20 trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar()) - assert trainer.progress_bar_callback.refresh_rate == 1 # FIXME: should be 20 + assert trainer.progress_bar_callback.refresh_rate == 20 trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar(refresh_rate=19)) assert trainer.progress_bar_callback.refresh_rate == 19 From 1702036c14a2e77f4a47df003cd9937d84fc7c9e Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 23 Nov 2021 12:30:50 +0000 Subject: [PATCH 26/59] Fault Tolerant Manual: Add stateful dataloader iter (#10674) --- CHANGELOG.md | 1 + pytorch_lightning/trainer/data_loading.py | 13 +- pytorch_lightning/utilities/auto_restart.py | 153 ++++++++++++++++++-- pytorch_lightning/utilities/fetching.py | 12 ++ tests/utilities/test_auto_restart.py | 100 +++++++++++++ 5 files changed, 259 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index adb1b070dc386..fd9cf54a9730e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647)) + * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674)) - diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index d7354a8294b37..6044f1320286c 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -15,7 +15,6 @@ import os from abc import ABC from copy import deepcopy -from functools import partial from typing import Any, Callable, Collection, List, Optional, Tuple, Union from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler @@ -29,7 +28,7 @@ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.auto_restart import _capture_metadata_collate +from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, _update_dataloader, @@ -215,7 +214,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - # add collate_fn to collect metadata for fault tolerant training if _fault_tolerant_training(): - apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate) + apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate) # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode) @@ -437,14 +436,6 @@ def request_dataloader( self.training_type_plugin.barrier("get_dataloaders") return dataloader - @staticmethod - def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: - """Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is - enabled.""" - dataloader.collate_fn = partial( - _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn - ) - @staticmethod def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]: all_have_sequential_sampler = True diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 4cb1793643c1d..4e984f7ecb2aa 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -21,11 +21,17 @@ import numpy as np import torch from torch.utils.data import Dataset, get_worker_info, Sampler -from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset +from torch.utils.data.dataloader import ( + _BaseDataLoaderIter, + _MultiProcessingDataLoaderIter, + _SingleProcessDataLoaderIter, + DataLoader, + IterableDataset, +) from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl -from pytorch_lightning.utilities.enums import AutoRestartBatchKeys +from 
pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -435,8 +441,10 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]: return {"num_workers": num_workers, "previous_worker": previous_worker} -def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict: - """A collate function that adds the state dict of a :class:`CaptureIterableDataset` or +def _capture_metadata_collate( + samples: List, dataset: Dataset, collate_fn: Callable, fault_tolerant_mode: _FaultTolerantMode +) -> Any: + """A collate_fn function that adds the state dict of a :class:`CaptureIterableDataset` or :class:`CaptureMapDataset` used in the worker processes. This function gets executed within the worker processes. The structure will be: @@ -447,10 +455,25 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: "__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, } """ - data = default_collate(samples) - if not isinstance(dataset, (CaptureIterableDataset, CaptureMapDataset)): - return data - metadata = dataset.state_dict() + data = collate_fn(samples) + metadata = None + if fault_tolerant_mode.is_automatic: + metadata = dataset.state_dict() + else: + state_dict_fn = getattr(dataset, "state_dict", None) + info = get_worker_info() + worker_id = info.id if info else 0 + if state_dict_fn is not None: + metadata = state_dict_fn() + if worker_id not in metadata: + if info and info.num_workers > 1: + raise MisconfigurationException( + f"The state_dict returned by {dataset} needs to be indexed by `worker_id` integer keys." + ) + metadata = {0: metadata} + if metadata is None: + metadata = {worker_id: {}} + return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata} @@ -480,6 +503,9 @@ def patch_dataloader_iterator( will extract the current iteration as part of the metadata returned by a custom batch. """ + if not _FaultTolerantMode.detect_current_mode().is_automatic: + return + assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)) def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable: @@ -527,8 +553,14 @@ def wrapper(): def _add_capture_metadata_collate(dataloader: DataLoader) -> None: """Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled.""" + faut_tolerant_mode = _FaultTolerantMode.detect_current_mode() + if not faut_tolerant_mode.is_enabled: + return dataloader.collate_fn = partial( - _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn + _capture_metadata_collate, + dataset=dataloader.dataset, + collate_fn=dataloader.collate_fn, + fault_tolerant_mode=faut_tolerant_mode, ) @@ -589,3 +621,106 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... + + +class _StatefulDataLoaderIter: + """This mixin is used to make PyTorch DataLoaderIter stateful.""" + + def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None: + # store sampler state within a queue alongside its idx. 
+ self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1 + self._sampler_state.append((sampler_state, self._sampler_state_idx)) + + def _store_sampler_state(self) -> None: + """This function is used to extract the sampler states if any.""" + sampler_state = { + k: v.state_dict() + for k, v in self._loader.__dict__.items() + if isinstance(v, _SupportsStateDict) and k != "dataset" + } + + self.__accumulate_state(sampler_state) + + def _next_index(self) -> Any: + indexes = super()._next_index() + self._store_sampler_state() + return indexes + + def _prepare_loader(self, loader): + if not isinstance(loader.collate_fn, partial): + loader.collate_fn = partial(_capture_metadata_collate, dataset=loader.dataset, collate_fn=loader.collate_fn) + self._loader = loader + self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher + self.num_batches_fetched = 0 + self._sampler_state = [] + self._sampler_state_idx = 0 + + def __del__(self) -> None: + if isinstance(self._loader.collate_fn, partial): + self._loader.collate_fn = self._loader.collate_fn.keywords["collate_fn"] + + def _next_data(self) -> Any: + combined_batch = super()._next_data() + + batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META] + + self.num_batches_fetched += 1 + + sampler_state, sampler_state_idx = self._sampler_state.pop(0) + # there is no workers within the samplers + worker_id = list(state.keys())[0] + + state = [ + IteratorState( + num_workers=self._loader.num_workers, + sampler_state=sampler_state, + dataset_state=state, + worker_id=worker_id, + num_batches_fetched=self.num_batches_fetched, + ) + ] + # ensures there is an alignement between the sampler state and currently fetched batch + assert sampler_state_idx == self.num_batches_fetched + self._data_fetcher._store_dataloader_iter_state(self, state) + return batch + + +class _SingleProcessDataLoaderIterStateful(_StatefulDataLoaderIter, _SingleProcessDataLoaderIter): + def __init__(self, loader: DataLoader): + self._prepare_loader(loader) + super().__init__(loader) + + +class _MultiProcessingDataLoaderIterStateful(_StatefulDataLoaderIter, _MultiProcessingDataLoaderIter): + def __init__(self, loader: DataLoader): + self._prepare_loader(loader) + super().__init__(loader) + + +def _get_iterator(self) -> "_BaseDataLoaderIter": + if not hasattr(self, "_lightning_fetcher"): + raise MisconfigurationException( + "A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader." + ) + if self.num_workers == 0: + return _SingleProcessDataLoaderIterStateful(self) + else: + if hasattr(self, "check_worker_number_rationality"): + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIterStateful(self) + + +def _patch_dataloader_get_iterators() -> None: + """This function is used to replace the DataLoader iterator by their stateful version.""" + if not hasattr(DataLoader, "_ori_get_iterator"): + DataLoader._ori_get_iterator = DataLoader._get_iterator + DataLoader._get_iterator = _get_iterator + + +def _teardown_dataloader_get_iterators() -> None: + """This function is used to restore the DataLoader `get_iterator` with its original one.""" + # cleanup the get_iterator replacement in case of Fault Tolerant Training. 
+ get_iterator = getattr(DataLoader, "_ori_get_iterator", None) + if get_iterator: + DataLoader._get_iterator = get_iterator + del DataLoader._ori_get_iterator diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 9b80d2f9874c7..f5bb4be032d10 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -99,6 +99,8 @@ def setup( if self.profiler is not None and stage is None: raise MisconfigurationException("When providing a profiler, the stage should be provided too.") + self._attach_data_fetcher() + @staticmethod def _add_capture_metadata_collate(dataloader: Iterable) -> None: if not isinstance(dataloader, (DataLoader, CombinedLoader)): @@ -190,6 +192,16 @@ def collect_state(iterator: Iterator): return apply_to_collection(self.loader_iters, Iterator, collect_state) + def _attach_data_fetcher(self): + def _attach_data_fetcher_fn(loader: DataLoader): + if isinstance(loader, CycleIterator): + loader = loader.loader + + if isinstance(loader, DataLoader) and _fault_tolerant_training(): + loader._lightning_fetcher = self + + apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn) + def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: if self.dataloader is None: raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.") diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index d9063f90db377..25f80ec6817a5 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -40,8 +40,12 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _MultiProcessingDataLoaderIterStateful, + _patch_dataloader_get_iterators, _rotate_worker_indices, + _SingleProcessDataLoaderIterStateful, _SupportsStateDict, + _teardown_dataloader_get_iterators, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -245,8 +249,10 @@ def __next__(self): return self.data[next(iter_sampler)] +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 30 sec and should be skipped in Azure CI") @pytest.mark.parametrize("num_workers", [0, 1, 2]) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_fast_forward_sampler_over_iterable_dataset(num_workers): """This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly being used to capture workers states.""" @@ -626,11 +632,13 @@ def all_gather(tensor, world_size): assert torch.equal(t, tr) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 45 sec and should be skipped in Azure CI") def test_fast_forward_sampler_iterative_dataset(): _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(0, 1) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 55 sec and should be skipped in Azure CI") @RunIf(skip_windows=True) def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(): @@ -1251,3 +1259,95 @@ def test_fault_tolerant_mode_enum(): ): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): _FaultTolerantMode.detect_current_mode() + + +class StatefulRandomSampler(RandomSampler): + + counter = 0 + + def state_dict(self): + self.counter 
+= 1
+        return {"counter": self.counter}
+
+    def load_state_dict(self, state_dict):
+        self.counter = state_dict["counter"]
+
+
+class StatefulRandomDataset(RandomDataset):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.counter = 0
+
+    def __getitem__(self, index):
+        self.counter += 1
+        return super().__getitem__(index)
+
+    def state_dict(self):
+        info = get_worker_info()
+        if info:
+            return {info.id: {"counter": self.counter}}
+        return {"counter": self.counter}
+
+    def load_state_dict(self, state_dict):
+        self.counter = state_dict["counter"]
+
+
+@pytest.mark.parametrize("num_workers", [0])
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
+def test_stateful_workers(num_workers):
+
+    seed_everything(42)
+
+    _get_iterator_fn = DataLoader._get_iterator
+    _patch_dataloader_get_iterators()
+    assert DataLoader._ori_get_iterator is not None
+
+    data_fetcher = DataFetcher()
+    dataset = StatefulRandomDataset(1, 64)
+    dataloader = DataLoader(dataset, sampler=StatefulRandomSampler(dataset), num_workers=num_workers)
+
+    with pytest.raises(MisconfigurationException, match="A stateful iterator should be used"):
+        iter(dataloader)
+
+    # This would attach the `data_fetcher` to the DataLoader.
+    data_fetcher.setup(dataloader)
+
+    data_fetcher_iter = iter(data_fetcher)
+
+    dataloader_iter = data_fetcher.dataloader_iter
+    worker_type = _SingleProcessDataLoaderIterStateful if num_workers == 0 else _MultiProcessingDataLoaderIterStateful
+    assert isinstance(dataloader_iter, worker_type)
+
+    next(data_fetcher_iter)
+    state = data_fetcher.dataloader_iter.state.state
+    assert state[0].dataset_state == {0: {"counter": 1}}
+    assert state[0].sampler_state["sampler"] == {"counter": 1}
+
+    next(data_fetcher_iter)
+    previous_state = data_fetcher.dataloader_iter.previous_state.state
+    state = data_fetcher.dataloader_iter.state.state
+    assert previous_state[0].dataset_state == {0: {"counter": 1}}
+    assert previous_state[0].sampler_state["sampler"] == {"counter": 1}
+    # TODO: Resolve the previous `sampler_state` associated to `worker_id: 0`.
+ worker_id = 1 if num_workers else 0 + assert state[worker_id].sampler_state["sampler"] == {"counter": 2} + + # each worker has its own copy of the dataset + assert state[0].dataset_state == ({0: {"counter": 2}} if num_workers == 0 else {0: {"counter": 1}}) + target_previous_state = deepcopy(state) + + next(data_fetcher_iter) + latest_worker_id = data_fetcher.dataloader_iter.state.latest_worker_id + assert latest_worker_id == 0 + previous_state = data_fetcher.dataloader_iter.previous_state.state + state = data_fetcher.dataloader_iter.state.state + + assert target_previous_state == previous_state + assert state[0].sampler_state["sampler"] == {"counter": 3} + assert state[0].dataset_state == ({0: {"counter": 3}} if num_workers == 0 else {0: {"counter": 2}}) + + _teardown_dataloader_get_iterators() + assert not hasattr(DataLoader, "_ori_get_iterator") + assert DataLoader._get_iterator == _get_iterator_fn + + data_fetcher.teardown() From ee9f7c0421b83f0d428a7943361f465c8978d4d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 23 Nov 2021 14:51:41 +0100 Subject: [PATCH 27/59] Update DeepSpeed precision handling after moving PrecisionPlugin (#10657) --- CHANGELOG.md | 4 ++ pytorch_lightning/lite/lite.py | 13 ------ .../plugins/precision/deepspeed.py | 6 ++- .../plugins/training_type/deepspeed.py | 44 +++++-------------- .../connectors/accelerator_connector.py | 2 +- 5 files changed, 21 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd9cf54a9730e..9354ca6cdee5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -163,6 +163,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the `precision_plugin` attribute from `Accelerator` in favor of its equivalent attribute `precision_plugin` in the `TrainingTypePlugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) + +- Removed `DeepSpeedPlugin.{precision,amp_type,amp_level}` properties ([#10657](https://github.com/PyTorchLightning/pytorch-lightning/pull/10657)) + + ### Fixed - When a tensor is logged with `self.log`, run its computation with the same `dtype` ([#10076](https://github.com/PyTorchLightning/pytorch-lightning/pull/10076)) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 4997d7db779e7..6c41c80a56171 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -385,7 +385,6 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) return seed_everything(seed=seed, workers=workers) def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: - self._set_plugin_specific_precision_variables() self._accelerator.setup_environment() # apply sharded context to prevent OOM @@ -400,11 +399,6 @@ def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: with self._strategy.model_sharded_context(), _replace_dataloader_init_method(): return run_method(*args, **kwargs) - def _set_plugin_specific_precision_variables(self) -> None: - # todo: these are hacks as plugins rely on access to the precision plugin - if isinstance(self._strategy, DeepSpeedPlugin): - self._set_deepspeed_precision_variables() - def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: if isinstance(self._strategy, TPUSpawnPlugin): # When the user creates the optimizer, they reference the parameters on the CPU. 
@@ -423,13 +417,6 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) - model = self.to_device(model) return model - def _set_deepspeed_precision_variables(self) -> None: - # TODO: Refactor this once precision pluging is part of the strategy. - amp_type = self._accelerator_connector.amp_type - amp_level = self._accelerator_connector.amp_level - precision = self._accelerator_connector.precision - self._strategy._amp_level, self._strategy._amp_type, self._strategy._precision = amp_level, amp_type, precision - def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: return ( self._accelerator_connector.is_distributed diff --git a/pytorch_lightning/plugins/precision/deepspeed.py b/pytorch_lightning/plugins/precision/deepspeed.py index 27ac384d25303..46cf023fc5d32 100644 --- a/pytorch_lightning/plugins/precision/deepspeed.py +++ b/pytorch_lightning/plugins/precision/deepspeed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union from torch import Tensor from torch.nn import Module @@ -34,9 +34,11 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin): """Precision plugin for DeepSpeed integration.""" - def __init__(self, precision: int) -> None: + def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None: super().__init__() self.precision = precision + self.amp_type = amp_type + self.amp_level = amp_level def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None: if is_overridden("backward", model): diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 01959bdcee212..86d380ac24ce8 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -34,10 +34,10 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import AMPType, GradClipAlgorithmType +from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only -from pytorch_lightning.utilities.enums import _StrategyType +from pytorch_lightning.utilities.distributed import log, rank_zero_info +from pytorch_lightning.utilities.enums import _StrategyType, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.model_helpers import is_overridden @@ -327,24 +327,6 @@ def __init__( self.hysteresis = hysteresis self.min_loss_scale = min_loss_scale - # optionally set by Lite - self._precision: Optional[Union[str, int]] = None - self._amp_level: Optional[str] = None - self._amp_type: Optional[str] = None - - @property - def precision(self) -> Union[str, int]: - return self._precision or self.precision_plugin.precision - - @property - def amp_level(self) -> Optional[str]: - if self._amp_type == AMPType.APEX: - return self._amp_level or 
self.lightning_module.trainer._accelerator_connector.amp_level - - @property - def amp_type(self) -> Optional[str]: - return self._amp_type or self.lightning_module.trainer._accelerator_connector.amp_type - def _load_config(self, config): if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") @@ -459,11 +441,11 @@ def init_deepspeed(self): "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs." ) - model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision) + model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision) if self.zero_stage_3 and self.partition_module: # Ensure the entire model has been moved to the appropriate device - dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32 + dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32 deepspeed.zero.Init( module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype ) @@ -520,7 +502,7 @@ def _initialize_deepspeed_train(self, model): def model_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: assert self._config_initialized - dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32 + dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32 model_parallel_context = deepspeed.zero.Init( remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype ) @@ -646,11 +628,9 @@ def _auto_select_batch_size(self): ) return batch_size - def _format_precision_config(self): - if self.amp_type == AMPType.APEX: - amp_level = self.amp_level - if self.precision in (16, "mixed"): - if "fp16" not in self.config and self.amp_type == AMPType.NATIVE: + def _format_precision_config(self) -> None: + if self.precision_plugin.precision in (16, "mixed"): + if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE: # FP16 is a DeepSpeed standalone AMP implementation rank_zero_info("Enabling DeepSpeed FP16.") self.config["fp16"] = { @@ -661,9 +641,9 @@ def _format_precision_config(self): "hysteresis": self.hysteresis, "min_loss_scale": self.min_loss_scale, } - elif "amp" not in self.config and self.amp_type == AMPType.APEX: - rank_zero_only("Enabling DeepSpeed APEX Implementation.") - self.config["amp"] = {"enabled": True, "opt_level": amp_level} + elif "amp" not in self.config and self.precision_plugin.amp_type == AMPType.APEX: + rank_zero_info("Enabling DeepSpeed APEX Implementation.") + self.config["amp"] = {"enabled": True, "opt_level": self.precision_plugin.amp_level} def _create_default_config( self, diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 7136437bbc69d..c95d46e77b977 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -637,7 +637,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: return TPUBf16PrecisionPlugin() if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin): - return DeepSpeedPrecisionPlugin(self.precision) + return DeepSpeedPrecisionPlugin(self.precision, self.amp_type, self.amp_level) if self.precision == 32: return PrecisionPlugin() From 
7cf6374bd042018323df06b188baaa1434898c14 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 23 Nov 2021 14:27:33 +0000 Subject: [PATCH 28/59] Fault Tolerant Manual: Add support for collecting states across processes (#10639) --- CHANGELOG.md | 1 + pytorch_lightning/utilities/distributed.py | 27 ++++++++++++++++++- tests/utilities/test_distributed.py | 30 ++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9354ca6cdee5a..da15315e6a7f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647)) * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674)) + * Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) - diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 1740518923c0f..7c6e4f4048181 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -16,7 +16,7 @@ import os from functools import wraps from platform import python_version -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel @@ -376,3 +376,28 @@ def init_dist_connection( f"All distributed processes registered. Starting with {world_size} processes\n" f"{'-' * 100}\n" ) + + +def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Any]]: + """This distributed utility collects dictionary state across all processes. + + Args: + state: Dictionary containing the state of the current process + device: Current process device. + + Returns: + states: On global rank 0, a dictionary where the primary keys are + the process rank and the values their associated states. Otherwise, returns None. 
+ """ + if not distributed_available(): + return {0: state} + states = {} + current_rank = torch.distributed.get_rank() + for rank in range(1, torch.distributed.get_world_size()): + objects = [state if current_rank == rank else None] + torch.distributed.broadcast_object_list(objects, src=rank, device=device) + states[rank] = objects[0] + if current_rank != 0: + return None + states[0] = state + return states diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py index e27b4264df126..a48b4486a470f 100644 --- a/tests/utilities/test_distributed.py +++ b/tests/utilities/test_distributed.py @@ -16,6 +16,12 @@ from unittest import mock import pytest +import torch +import torch.multiprocessing as mp + +import tests.helpers.utils as tutils +from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero +from tests.helpers.runif import RunIf @pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}]) @@ -53,3 +59,27 @@ def foo(): x = foo() assert x is None + + +def _test_collect_states(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + + # initialize the process group + torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) + + state = {"something": torch.tensor([rank])} + collected_state = _collect_states_on_rank_zero(state, device=torch.device(f"cuda:{rank}")) + if rank == 0: + assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}} + else: + assert collected_state is None + + +@RunIf(skip_windows=True, min_gpus=2, min_torch="1.10") +def test_collect_states(): + """This test ensures state are properly collected across processes. + + This would be used to collect dataloader states as an example. + """ + tutils.set_random_main_port() + mp.spawn(_test_collect_states, args=(2,), nprocs=2) From dca1776870e9dd7236c03d7ccd55fcb09a9bb5e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 23 Nov 2021 16:35:07 +0100 Subject: [PATCH 29/59] LiteDataLoader wrapper improvements (#10297) --- pytorch_lightning/lite/wrappers.py | 41 ++++++++++++------------ tests/lite/test_lite.py | 50 +++++++++++++----------------- 2 files changed, 41 insertions(+), 50 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 3cd2f5eb69712..3bd94b5f36b74 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -15,7 +15,7 @@ import inspect from contextlib import contextmanager from itertools import chain -from typing import Any, Callable, Dict, Generator, Iterator, Optional, Set, Type, Union +from typing import Any, Callable, Generator, Iterator, Optional, Set, Type, Union import torch from torch import nn as nn @@ -110,21 +110,25 @@ def _convert_float_tensor(t: Tensor) -> Tensor: return output -def _wrap_init(f: Callable) -> Callable: - @functools.wraps(f) - def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None: - params = dict(inspect.signature(module._old_init).parameters) +def _wrap_init(init: Callable) -> Callable: + """Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of + :class:`~torch.utils.data.DataLoader`.""" + + @functools.wraps(init) + def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None: + params = dict(inspect.signature(obj.__init__).parameters) params.pop("args", None) params.pop("kwargs", None) - for init_name, init_arg in chain(zip(params, args), kwargs.items()): - setattr(module, 
init_name, init_arg) - f(module, *args, **kwargs) + for arg_name, arg_value in chain(zip(params, args), kwargs.items()): + setattr(obj, arg_name, arg_value) + init(obj, *args, **kwargs) return wrapper # https://stackoverflow.com/a/63851681/9201239 def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]: + """Returns a list of all classes that inherit directly or indirectly from the given class.""" subclasses = set() def recurse(cl: Type[Any]) -> None: @@ -136,24 +140,17 @@ def recurse(cl: Type[Any]) -> None: return subclasses -def _enable_class(cls: Type[Any]) -> None: - cls._old_init = cls.__init__ - cls.__init__ = _wrap_init(cls.__init__) - - -def _disable_class(cls: Type[Any]) -> None: - cls.__init__ = cls._old_init - del cls._old_init - - @contextmanager -def _replace_dataloader_init_method() -> Generator: - """This context manager is used to support custom :class:`~torch.utils.data.DataLoader.""" +def _replace_dataloader_init_method() -> Generator[None, None, None]: + """This context manager is used to add support for re-instantiation of custom (subclasses) of + :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method.""" for subclass in _get_all_subclasses(DataLoader): - _enable_class(subclass) + subclass._old_init = subclass.__init__ + subclass.__init__ = _wrap_init(subclass.__init__) yield for subclass in _get_all_subclasses(DataLoader): - _disable_class(subclass) + subclass.__init__ = subclass._old_init + del subclass._old_init class _LiteDataLoader: diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 7c79cb7f2e709..f9ed4a9da7d9d 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -24,12 +24,7 @@ from torch.utils.data import DataLoader, DistributedSampler, Sampler from pytorch_lightning.lite import LightningLite -from pytorch_lightning.lite.wrappers import ( - _LiteDataLoader, - _LiteModule, - _LiteOptimizer, - _replace_dataloader_init_method, -) +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.utilities import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -197,6 +192,27 @@ def run(self): LiteWithCustomDataLoader().run() +def test_setup_dataloaders_raises_for_unknown_custom_args(): + """Test that an error raises when custom dataloaders with unknown arguments are created from outside Lite's run + method.""" + lite = EmptyLite() + + class CustomDataLoader(DataLoader): + def __init__(self, new_arg, *args, **kwargs): + super().__init__(range(5), *args, **kwargs) + + with pytest.raises( + MisconfigurationException, + match=( + r"Trying to inject `DistributedSampler` into the `CustomDataLoader` instance.*" + r"The missing attributes are \['new_arg'\]" + ), + ): + # The dataloader was not created within the run function, and therefore init args were not intercepted + dataloader = CustomDataLoader(2, batch_size=2) + lite.setup_dataloaders(dataloader) + + def test_setup_dataloaders_twice_fails(): """Test that calling setup_dataloaders with a dataloader that is already wrapped fails.""" lite = EmptyLite() @@ -444,25 +460,3 @@ def run(self): assert self.is_global_zero == (self.local_rank == 0) Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run() - - -def test_replace_dataloader_init_method(): - """Test that the context manager enables to save the parameters passed to the DataLoader 
__init__ method.""" - - class CustomDataLoader(DataLoader): - def __init__(self, extra_argument: int, *args, **kwargs): - super().__init__(*args, **kwargs) - - dataloader = CustomDataLoader(extra_argument=1, dataset=range(1)) - lite = EmptyLite() - with pytest.raises(MisconfigurationException, match="extra_argument"): - dataloader = lite.setup_dataloaders(dataloader) - - with _replace_dataloader_init_method(): - dataloader = CustomDataLoader(extra_argument=1, dataset=range(1)) - assert dataloader.extra_argument == 1 - dataloader = lite.setup_dataloaders(dataloader) - - dataloader = CustomDataLoader(1, range(1)) - assert dataloader.extra_argument == 1 - dataloader = lite.setup_dataloaders(dataloader) From b28ab34ff5b68b311dd145abe3f6badf69eb4159 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 23 Nov 2021 17:18:36 +0000 Subject: [PATCH 30/59] Fault Tolerant Manual: Add loading to reload the states (#10699) Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: Carlos Mocholi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 2 +- .../loops/epoch/evaluation_epoch_loop.py | 4 +- pytorch_lightning/trainer/supporters.py | 4 +- pytorch_lightning/utilities/auto_restart.py | 98 ++++++++++++++----- tests/utilities/test_auto_restart.py | 10 +- 5 files changed, 87 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index da15315e6a7f2..1f7014a71d9a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647)) * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674)) * Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) - + * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699)) - diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 102603f20302b..2fc572ea252e6 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -22,7 +22,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.utilities import _update_dataloader_iter from pytorch_lightning.trainer.progress import BatchProgress -from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict +from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT @@ -182,7 +182,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher): if not self.trainer.sanity_checking and self._dataloader_state_dict: - reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict) + _reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict) self._dataloader_state_dict = None def 
_num_completed_batches_reached(self) -> bool: diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 6e2e51e82bbf1..d65bc08e6689e 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -24,9 +24,9 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( + _reload_dataloader_state_dict, MergedIteratorState, patch_dataloader_iterator, - reload_dataloader_state_dict, ) from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -403,7 +403,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator: if isinstance(dataloader, CycleIterator): dataloader = dataloader_to_iter_on.loader - reload_dataloader_state_dict(dataloader, state_dict) + _reload_dataloader_state_dict(dataloader, state_dict) # We finally spawned the workers if any. it = iter(dataloader_to_iter_on) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 4e984f7ecb2aa..3fa32bc72da5e 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -33,7 +33,6 @@ import pytorch_lightning as pl from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training class FastForwardSampler(Sampler): @@ -564,38 +563,90 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None: ) -def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: - """Utility to reload state_dict within dataloader for fault tolerance.""" +def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: + iterator_state = state_dict["state"][0] - if not _fault_tolerant_training(): - return + if not isinstance(iterator_state, IteratorState): + iterator_state = IteratorState.from_state_dict(iterator_state) - dataset = dataloader.dataset + # reload sampler state + ff_sampler = _find_fast_forward_samplers(dataloader) + ff_sampler.load_state_dict(iterator_state.sampler_state) - if isinstance(dataset, CaptureMapDataset): - iterator_state = state_dict["state"][0] + # reload dataset state + dataloader.dataset.load_state_dict( + iterator_state.dataset_state, + latest_worker_id=state_dict["latest_worker_id"], + num_workers=iterator_state.num_workers, + ) - if not isinstance(iterator_state, IteratorState): - iterator_state = IteratorState.from_state_dict(iterator_state) - # reload sampler state - ff_sampler = _find_fast_forward_samplers(dataloader) - ff_sampler.load_state_dict(iterator_state.sampler_state) +def _reload_dataloader_state_dict_automatic_iterable_dataset( + dataset: CaptureIterableDataset, state_dict: Dict[str, Any] +) -> None: + dataset.load_state_dict( + {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()} + ) - # reload dataset state - dataset.load_state_dict( - iterator_state.dataset_state, - latest_worker_id=state_dict["latest_worker_id"], - num_workers=iterator_state.num_workers, - ) + +def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: + dataset = dataloader.dataset + if isinstance(dataset, CaptureMapDataset): + 
_reload_dataloader_state_dict_automatic_map_dataset(dataloader, state_dict) elif isinstance(dataset, CaptureIterableDataset): - dataset.load_state_dict( - {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()} - ) + _reload_dataloader_state_dict_automatic_iterable_dataset(dataset, state_dict) + + else: + raise MisconfigurationException("This shouldn't be happening. Please, open an issue.") + + +def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: + # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset` + # therefore, we need to reload the states manually. + + latest_worker_id = state_dict["latest_worker_id"] + num_workers = state_dict["state"][latest_worker_id]["num_workers"] + sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None) + if sampler_state: + # `sampler_state` keys contain all the DataLoader attribute names + # which matched `_SupportsStateDict` API interface while collecting the `state_dict`. + for dataloader_attr_name in sampler_state: + obj = getattr(dataloader, dataloader_attr_name) + if not isinstance(obj, _SupportsStateDict): + raise MisconfigurationException( + f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method." + ) + + obj.load_state_dict(sampler_state[dataloader_attr_name]) + + if not isinstance(dataloader.dataset, _SupportsStateDict): + return + + dataset_state = { + worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id] + for worker_id in state_dict["state"].keys() + } + + dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers)) + + +def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: + """Utility to reload state_dict within dataloader for fault tolerance.""" + + fault_tolerant_mode = _FaultTolerantMode.detect_current_mode() + + if not fault_tolerant_mode.is_enabled: + return + + if fault_tolerant_mode.is_automatic: + _reload_dataloader_state_dict_automatic(dataloader, state_dict) + + elif fault_tolerant_mode.is_manual: + _reload_dataloader_state_dict_manual(dataloader, state_dict) else: - raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") + raise MisconfigurationException("This shouldn't be happening. 
Please, open an issue.") def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]: @@ -638,7 +689,6 @@ def _store_sampler_state(self) -> None: for k, v in self._loader.__dict__.items() if isinstance(v, _SupportsStateDict) and k != "dataset" } - self.__accumulate_state(sampler_state) def _next_index(self) -> Any: diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 25f80ec6817a5..1c27d582cc6a5 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -19,6 +19,7 @@ from collections.abc import Iterable from contextlib import suppress from copy import deepcopy +from dataclasses import asdict from typing import List, Optional from unittest import mock from unittest.mock import ANY @@ -42,6 +43,7 @@ _dataloader_to_state_dict, _MultiProcessingDataLoaderIterStateful, _patch_dataloader_get_iterators, + _reload_dataloader_state_dict, _rotate_worker_indices, _SingleProcessDataLoaderIterStateful, _SupportsStateDict, @@ -1289,7 +1291,7 @@ def state_dict(self): return {"counter": self.counter} def load_state_dict(self, state_dict): - self.counter = state_dict["counter"] + self.counter = state_dict[0]["counter"] @pytest.mark.parametrize("num_workers", [0]) @@ -1319,7 +1321,9 @@ def test_stateful_workers(num_workers): assert isinstance(dataloader_iter, worker_type) next(data_fetcher_iter) - state = data_fetcher.dataloader_iter.state.state + + reloaded_state = deepcopy(data_fetcher.dataloader_iter.state) + state = reloaded_state.state assert state[0].dataset_state == {0: {"counter": 1}} assert state[0].sampler_state["sampler"] == {"counter": 1} @@ -1350,4 +1354,6 @@ def test_stateful_workers(num_workers): assert not hasattr(DataLoader, "_ori_get_iterator") assert DataLoader._get_iterator == _get_iterator_fn + _reload_dataloader_state_dict(dataloader, asdict(reloaded_state)) + assert dataloader.sampler.counter == dataloader.dataset.counter == 1 data_fetcher.teardown() From 89d0064b33a8a8e60177ccca4fc176333941db4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 23 Nov 2021 18:23:36 +0100 Subject: [PATCH 31/59] Use `PrecisionType` enum instead of checking raw values (#10704) * use precision type * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../plugins/training_type/deepspeed.py | 16 ++++++++++++---- .../plugins/training_type/fully_sharded.py | 4 ++-- pytorch_lightning/plugins/training_type/ipu.py | 3 ++- .../plugins/training_type/sharded.py | 4 ++-- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 86d380ac24ce8..3a704fc5f1848 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -37,7 +37,7 @@ from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import log, rank_zero_info -from pytorch_lightning.utilities.enums import _StrategyType, AMPType +from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from 
pytorch_lightning.utilities.model_helpers import is_overridden @@ -445,7 +445,11 @@ def init_deepspeed(self): if self.zero_stage_3 and self.partition_module: # Ensure the entire model has been moved to the appropriate device - dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32 + dtype = ( + torch.float16 + if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED) + else torch.float32 + ) deepspeed.zero.Init( module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype ) @@ -502,7 +506,11 @@ def _initialize_deepspeed_train(self, model): def model_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: assert self._config_initialized - dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32 + dtype = ( + torch.float16 + if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED) + else torch.float32 + ) model_parallel_context = deepspeed.zero.Init( remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype ) @@ -629,7 +637,7 @@ def _auto_select_batch_size(self): return batch_size def _format_precision_config(self) -> None: - if self.precision_plugin.precision in (16, "mixed"): + if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED): if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE: # FP16 is a DeepSpeed standalone AMP implementation rank_zero_info("Enabling DeepSpeed FP16.") diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 73ea87b05835e..38fa2942a7819 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -21,7 +21,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE -from pytorch_lightning.utilities.enums import _StrategyType +from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: @@ -139,7 +139,7 @@ def wrap_policy(*args, **kwargs): cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, flatten_parameters=self.flatten_parameters, - mixed_precision=precision == "mixed", + mixed_precision=(precision == PrecisionType.MIXED), reshard_after_forward=self.reshard_after_forward, fp32_reduce_scatter=self.fp32_reduce_scatter, compute_dtype=self.compute_dtype, diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index ef9b3d1f02b82..8f8f082280156 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -29,6 +29,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs +from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException if _POPTORCH_AVAILABLE: @@ -41,7 +42,7 @@ def __init__(self, pl_module: "pl.LightningModule", precision: Union[str, int]): self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: - if 
self.precision in ("mixed", 16): + if self.precision in (PrecisionType.MIXED, PrecisionType.HALF): inputs = self._move_float_tensors_to_half(inputs) return super().forward(*inputs, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index c9627324eb237..e7f57e9c92791 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -23,7 +23,7 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only -from pytorch_lightning.utilities.enums import _StrategyType +from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_AVAILABLE: @@ -71,7 +71,7 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin optim_class = type(optimizer) zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE: - is_fp16 = self.precision_plugin.precision in ("mixed", 16) + is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF) # For multi-node training, compressing the model shards in fp16 before broadcasting # improves performance. When using PyTorch AMP, it will not degrade # the model performance. From f36b395c4eb4c8d113954d768444194b3729be28 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 24 Nov 2021 17:01:03 +0530 Subject: [PATCH 32/59] Update `LightningDataModule` docs (#10678) --- docs/source/common/lightning_module.rst | 16 +- docs/source/common/trainer.rst | 14 +- docs/source/extensions/datamodules.rst | 200 ++++++++++-------- pytorch_lightning/core/hooks.py | 20 +- .../trainer/connectors/data_connector.py | 5 +- pytorch_lightning/trainer/trainer.py | 3 +- tests/deprecated_api/test_remove_1-7.py | 4 +- 7 files changed, 153 insertions(+), 109 deletions(-) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index ca6950fe36e0b..7a9bd95f12233 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1031,7 +1031,21 @@ use_amp ~~~~~~~ ``True`` if using Automatic Mixed Precision (AMP) --------------- +------------ + +prepare_data_per_node +~~~~~~~~~~~~~~~~~~~~~ +If set to ``True`` will call ``prepare_data()`` on LOCAL_RANK=0 for every node. +If set to ``False`` will only call from NODE_RANK=0, LOCAL_RANK=0. + +.. testcode:: + + class LitModel(LightningModule): + def __init__(self): + super().__init__() + self.prepare_data_per_node = True + +------------ automatic_optimization ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 7e336c758f621..e79b5e90484d0 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1131,6 +1131,16 @@ To define your own behavior, subclass the relevant class and pass it in. Here's prepare_data_per_node ^^^^^^^^^^^^^^^^^^^^^ +.. warning:: ``prepare_data_per_node`` has been deprecated in v1.5 and will be removed in v1.7. + Please set its value inside ``LightningDataModule`` and/or ``LightningModule`` directly described + in the following code: + + .. 
testcode:: + + class LitDataModule(LightningDataModule): + def __init__(self): + super().__init__() + self.prepare_data_per_node = True .. raw:: html @@ -1140,8 +1150,8 @@ prepare_data_per_node | -If True will call `prepare_data()` on LOCAL_RANK=0 for every node. -If False will only call from NODE_RANK=0, LOCAL_RANK=0 +If set to ``True`` will call ``prepare_data()`` on LOCAL_RANK=0 for every node. +If set to ``False`` will only call from NODE_RANK=0, LOCAL_RANK=0. .. testcode:: diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 7b8e2cd3754e6..6dc3398dab31d 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -30,10 +30,10 @@ This class can then be shared and used anywhere: trainer = Trainer() imagenet = ImagenetDataModule() - trainer.fit(model, imagenet) + trainer.fit(model, datamodule=imagenet) cifar10 = CIFAR10DataModule() - trainer.fit(model, cifar10) + trainer.fit(model, datamodule=cifar10) --------------- @@ -53,8 +53,8 @@ Datamodules are for you if you ever asked the questions: What is a DataModule -------------------- -A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) along with the -matching transforms and data processing/downloads steps required. +A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) and +predict_dataloader(s) along with the matching transforms and data processing/downloads steps required. Here's a simple PyTorch example: @@ -62,12 +62,14 @@ Here's a simple PyTorch example: # regular PyTorch test_data = MNIST(my_path, train=False, download=True) + predict_data = MNIST(my_path, train=False, download=True) train_data = MNIST(my_path, train=True, download=True) train_data, val_data = random_split(train_data, [55000, 5000]) train_loader = DataLoader(train_data, batch_size=32) val_loader = DataLoader(val_data, batch_size=32) test_loader = DataLoader(test_data, batch_size=32) + predict_loader = DataLoader(predict_data, batch_size=32) The equivalent DataModule just organizes the same exact code, but makes it reusable across projects. @@ -81,6 +83,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa def setup(self, stage: Optional[str] = None): self.mnist_test = MNIST(self.data_dir, train=False) + self.mnist_predict = MNIST(self.data_dir, train=False) mnist_full = MNIST(self.data_dir, train=True) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) @@ -93,6 +96,9 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=self.batch_size) + def predict_dataloader(self): + return DataLoader(self.mnist_predict, batch_size=self.batch_size) + def teardown(self, stage: Optional[str] = None): # Used to clean-up when the run is finished ... @@ -127,10 +133,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th self.data_dir = data_dir self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) - # Setting default dims here because we know them. 
- # Could optionally be assigned dynamically in dm.setup() - self.dims = (1, 28, 28) - def prepare_data(self): # download MNIST(self.data_dir, train=True, download=True) @@ -143,15 +145,12 @@ Here's a more realistic, complex DataModule that shows how much more reusable th mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) - # Optionally... - # self.dims = tuple(self.mnist_train[0][0].shape) - # Assign test dataset for use in dataloader(s) if stage == "test" or stage is None: self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform) - # Optionally... - # self.dims = tuple(self.mnist_test[0][0].shape) + if stage == "predict" or stage is None: + self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform) def train_dataloader(self): return DataLoader(self.mnist_train, batch_size=32) @@ -162,25 +161,29 @@ Here's a more realistic, complex DataModule that shows how much more reusable th def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=32) + def predict_dataloader(self): + return DataLoader(self.mnist_predict, batch_size=32) + --------------- LightningDataModule API ----------------------- -To define a DataModule define 5 methods: +To define a DataModule the following methods are used to create train/val/test/predict dataloaders: -- prepare_data (how to download(), tokenize, etc...) -- setup (how to split, etc...) -- train_dataloader -- val_dataloader(s) -- test_dataloader(s) - -and optionally one or multiple predict_dataloader(s). +- :ref:`prepare_data` (how to download, tokenize, etc...) +- :ref:`setup` (how to split, define dataset, etc...) +- :ref:`train_dataloader` +- :ref:`val_dataloader` +- :ref:`test_dataloader` +- :ref:`predict_dataloader` prepare_data -^^^^^^^^^^^^ -Use this method to do things that might write to disk or that need to be done only from a single process in distributed -settings. +~~~~~~~~~~~~ +Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning +ensures the :meth:`~pytorch_lightning.core.hooks.DataHooks.prepare_data` is called only within a single process, +so you can safely add your downloading logic within. In case of multi-node training, the execution of this hook +depends upon :ref:`prepare_data_per_node`. - download - tokenize @@ -195,16 +198,17 @@ settings. MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) -.. warning:: ``prepare_data`` is called from a single process (e.g. GPU 0). Do not use it to assign state (`self.x = y`). +.. warning:: ``prepare_data`` is called from the main process. It is not recommended to assign state here (e.g. ``self.x = y``). setup -^^^^^ -There are also data operations you might want to perform on every GPU. Use setup to do things like: +~~~~~ +There are also data operations you might want to perform on every GPU. Use :meth:`~pytorch_lightning.core.hooks.DataHooks.setup` to do things like: - count number of classes - build vocabulary - perform train/val/test splits +- create datasets - apply transforms (defined explicitly in your datamodule) - etc... @@ -220,25 +224,25 @@ There are also data operations you might want to perform on every GPU. 
Use setup if stage in (None, "fit"): mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) - self.dims = self.mnist_train[0][0].shape # Assign Test split(s) for use in Dataloaders if stage in (None, "test"): self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform) - self.dims = getattr(self, "dims", self.mnist_test[0][0].shape) -:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects an ``stage: Optional[str]`` argument. -It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``, +This method expects a ``stage`` argument. +It is used to separate setup logic for ``trainer.{fit,validate,test,predict}``. If ``setup`` is called with ``stage=None``, we assume all stages have been set-up. -.. note:: ``setup`` is called from every process. Setting state here is okay. -.. note:: ``teardown`` can be used to clean up the state. It is also called from every process +.. note:: :ref:`setup` is called from every process across all the nodes. Setting state here is recommended. +.. note:: :ref:`teardown` can be used to clean up the state. It is also called from every process across all the nodes. train_dataloader -^^^^^^^^^^^^^^^^ -Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in ``setup``. +~~~~~~~~~~~~~~~~ +Use the :meth:`~pytorch_lightning.core.hooks.DataHooks.train_dataloader` method to generate the training dataloader(s). +Usually you just wrap the dataset you defined in :ref:`setup`. This is the dataloader that the Trainer +:meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` method uses. .. code-block:: python @@ -251,8 +255,10 @@ Use this method to generate the train dataloader. Usually you just wrap the data val_dataloader -^^^^^^^^^^^^^^ -Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in ``setup``. +~~~~~~~~~~~~~~ +Use the :meth:`~pytorch_lightning.core.hooks.DataHooks.val_dataloader` method to generate the validation dataloader(s). +Usually you just wrap the dataset you defined in :ref:`setup`. This is the dataloader that the Trainer +:meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate` methods uses. .. code-block:: python @@ -267,8 +273,10 @@ Use this method to generate the val dataloader. Usually you just wrap the datase .. _datamodule-test-dataloader-label: test_dataloader -^^^^^^^^^^^^^^^ -Use this method to generate the test dataloader. Usually you just wrap the dataset you defined in ``setup``. +~~~~~~~~~~~~~~~ +Use the :meth:`~pytorch_lightning.core.hooks.DataHooks.test_dataloader` method to generate the test dataloader(s). +Usually you just wrap the dataset you defined in :ref:`setup`. This is the dataloader that the Trainer +:meth:`~pytorch_lightning.trainer.trainer.Trainer.test` method uses. .. code-block:: python @@ -281,8 +289,9 @@ Use this method to generate the test dataloader. Usually you just wrap the datas predict_dataloader -^^^^^^^^^^^^^^^^^^ -Returns a special dataloader for inference. This is the dataloader that the Trainer +~~~~~~~~~~~~~~~~~~ +Use the :meth:`~pytorch_lightning.core.hooks.DataHooks.predict_dataloader` method to generate the prediction dataloader(s). +Usually you just wrap the dataset you defined in :ref:`setup`. 
This is the dataloader that the Trainer :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict` method uses. .. code-block:: python @@ -292,76 +301,80 @@ Returns a special dataloader for inference. This is the dataloader that the Trai class MNISTDataModule(pl.LightningDataModule): def predict_dataloader(self): - return DataLoader(self.mnist_test, batch_size=64) + return DataLoader(self.mnist_predict, batch_size=64) transfer_batch_to_device -^^^^^^^^^^^^^^^^^^^^^^^^ -Override to define how you want to move an arbitrary batch to a device. -To check the current state of execution of this hook you can use ``self.trainer.training/testing/validating/predicting/sanity_checking`` -so that you can add different logic as per your requirement. +~~~~~~~~~~~~~~~~~~~~~~~~ -.. testcode:: +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.transfer_batch_to_device + :noindex: + +on_before_batch_transfer +~~~~~~~~~~~~~~~~~~~~~~~~ - class MNISTDataModule(LightningDataModule): - def transfer_batch_to_device(self, batch, device, dataloader_idx): - x = batch["x"] - x = CustomDataWrapper(x) - batch["x"] = x.to(device) - return batch +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_before_batch_transfer + :noindex: +on_after_batch_transfer +~~~~~~~~~~~~~~~~~~~~~~~ -.. note:: This hook only runs on single GPU training and DDP (no data-parallel). +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_after_batch_transfer + :noindex: +on_load_checkpoint +~~~~~~~~~~~~~~~~~~ -on_before_batch_transfer -^^^^^^^^^^^^^^^^^^^^^^^^ -Override to alter or apply augmentations to your batch before it is transferred to the device. -To check the current state of execution of this hook you can use ``self.trainer.training/testing/validating/predicting/sanity_checking`` -so that you can add different logic as per your requirement. +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_load_checkpoint + :noindex: -.. testcode:: +on_save_checkpoint +~~~~~~~~~~~~~~~~~~ - class MNISTDataModule(LightningDataModule): - def on_before_batch_transfer(self, batch, dataloader_idx): - batch["x"] = transforms(batch["x"]) - return batch +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_save_checkpoint + :noindex: +on_train_dataloader +~~~~~~~~~~~~~~~~~~~ -.. note:: This hook only runs on single GPU training and DDP (no data-parallel). +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_train_dataloader + :noindex: +on_val_dataloader +~~~~~~~~~~~~~~~~~ -on_after_batch_transfer -^^^^^^^^^^^^^^^^^^^^^^^ -Override to alter or apply augmentations to your batch after it is transferred to the device. -To check the current state of execution of this hook you can use ``self.trainer.training/testing/validating/predicting/sanity_checking`` -so that you can add different logic as per your requirement. +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_val_dataloader + :noindex: -.. testcode:: +on_test_dataloader +~~~~~~~~~~~~~~~~~~ - class MNISTDataModule(LightningDataModule): - def on_after_batch_transfer(self, batch, dataloader_idx): - batch["x"] = gpu_transforms(batch["x"]) - return batch +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_test_dataloader + :noindex: +on_predict_dataloader +~~~~~~~~~~~~~~~~~~~~~ -.. note:: - This hook only runs on single GPU training and DDP (no data-parallel). 
This hook - will also be called when using CPU device, so adding augmentations here or in - ``on_before_batch_transfer`` means the same thing. +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_predict_dataloader + :noindex: +teardown +~~~~~~~~ +.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.teardown + :noindex: -.. note:: To decouple your data from transforms you can parametrize them via ``__init__``. +prepare_data_per_node +~~~~~~~~~~~~~~~~~~~~~ +If set to ``True`` will call ``prepare_data()`` on LOCAL_RANK=0 for every node. +If set to ``False`` will only call from NODE_RANK=0, LOCAL_RANK=0. -.. code-block:: python +.. testcode:: - class MNISTDataModule(pl.LightningDataModule): - def __init__(self, train_transforms, val_transforms, test_transforms): + class LitDataModule(LightningDataModule): + def __init__(self): super().__init__() - self.train_transforms = train_transforms - self.val_transforms = val_transforms - self.test_transforms = test_transforms + self.prepare_data_per_node = True ------------------ @@ -375,12 +388,14 @@ The recommended way to use a DataModule is simply: dm = MNISTDataModule() model = Model() - trainer.fit(model, dm) + trainer.fit(model, datamodule=dm) trainer.test(datamodule=dm) + trainer.validate(datamodule=dm) + trainer.predict(datamodule=dm) If you need information from the dataset to build your model, then run -:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and -:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures +:ref:`prepare_data` and +:ref:`setup` manually (Lightning ensures the method runs on the correct devices). .. code-block:: python @@ -413,6 +428,7 @@ You can of course use DataModules in plain PyTorch code as well. # use data for batch in dm.train_dataloader(): ... + for batch in dm.val_dataloader(): ... @@ -444,4 +460,8 @@ Like LightningModules, DataModules support hyperparameters with the same API. super().__init__() self.save_hyperparameters() -Refer to `save_hyperparameters` in :doc:`lightning module <../common/lightning_module>` for more details. + def configure_optimizers(self): + # access the saved hyperparameters + opt = optim.Adam(self.parameters(), lr=self.hparams.lr) + +Refer to ``save_hyperparameters`` in :doc:`lightning module <../common/lightning_module>` for more details. diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 5263c16952fec..5dfe3e986e14a 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -323,10 +323,12 @@ def __init__(self) -> None: self.allow_zero_length_dataloader_with_multiple_devices: bool = False def prepare_data(self) -> None: - """Use this to download and prepare data. + """Use this to download and prepare data. Downloading and saving data with multiple processes (distributed + settings) will result in corrupted data. Lightning ensures this method is called only within a single + process, so you can safely add your downloading logic within. - .. warning:: DO NOT set state to the model (use `setup` instead) - since this is NOT called on every GPU in DDP/TPU + .. 
warning:: DO NOT set state to the model (use ``setup`` instead) + since this is NOT called on every device Example:: @@ -340,11 +342,13 @@ def prepare_data(self): self.split = data_split self.some_state = some_other_state() - In DDP prepare_data can be called in two ways (using Trainer(prepare_data_per_node)): + In DDP ``prepare_data`` can be called in two ways (using Trainer(prepare_data_per_node)): 1. Once per node. This is the default and is only called on LOCAL_RANK=0. 2. Once in total. Only called on GLOBAL_RANK=0. + See :ref:`prepare_data_per_node`. + Example:: # DEFAULT @@ -354,10 +358,6 @@ def prepare_data(self): # call on GLOBAL_RANK=0 (great for shared file systems) Trainer(prepare_data_per_node=False) - Note: - Setting ``prepare_data_per_node`` with the trainer flag is deprecated and will be removed in v1.7.0. - Please set ``prepare_data_per_node`` in LightningDataModule or LightningModule directly instead. - This is called before requesting the dataloaders: .. code-block:: python @@ -371,7 +371,7 @@ def prepare_data(self): """ def setup(self, stage: Optional[str] = None) -> None: - """Called at the beginning of fit (train + validate), validate, test, and predict. This is a good hook when + """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. @@ -397,7 +397,7 @@ def setup(stage): """ def teardown(self, stage: Optional[str] = None) -> None: - """Called at the end of fit (train + validate), validate, test, predict, or tune. + """Called at the end of fit (train + validate), validate, test, or predict. Args: stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index de81060ba1f80..deee64c90fe2e 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -70,8 +70,9 @@ def on_trainer_init( if prepare_data_per_node is not None: rank_zero_deprecation( - "Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0! " - "Please set `prepare_data_per_node` in LightningDataModule or LightningModule directly instead. " + "Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0 and will be removed in" + " v1.7.0. Please set `prepare_data_per_node` in `LightningDataModule` and/or `LightningModule`" + " directly instead." ) self.trainer.prepare_data_per_node = prepare_data_per_node diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1ccdb9ecaeca8..667b57fd1c76f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -284,7 +284,8 @@ def __init__( .. deprecated:: v1.5 Deprecated in v1.5.0 and will be removed in v1.7.0 - Please set ``prepare_data_per_node`` in LightningDataModule or LightningModule directly instead. + Please set ``prepare_data_per_node`` in ``LightningDataModule`` and/or + ``LightningModule`` directly instead. process_position: Orders the progress bar when running multiple models on same machine. 
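[Editor's note — illustrative sketch, not part of the patch series above.] The hunks in this commit document moving ``prepare_data_per_node`` off the deprecated ``Trainer`` flag and onto the ``LightningDataModule``/``LightningModule`` itself, and tighten the ``prepare_data``/``setup`` guidance. A minimal migration sketch follows; the ``MyDataModule`` class, the MNIST/torchvision dataset choice, and the split sizes are hypothetical examples, while the ``prepare_data_per_node`` attribute and the "download in ``prepare_data``, assign state in ``setup``" split are taken from the documentation changes in this commit.

```python
# Hypothetical migration sketch -- not part of the diff above.
# Before (deprecated in v1.5, removed in v1.7):  Trainer(prepare_data_per_node=False)
# After: set the attribute on the LightningDataModule (or LightningModule) directly.
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST  # assumed available for the example


class MyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./data"):
        super().__init__()
        self.data_dir = data_dir
        # False -> prepare_data() runs once in total (NODE_RANK=0, LOCAL_RANK=0),
        # e.g. on a shared filesystem; True (the default) runs it once per node.
        self.prepare_data_per_node = False

    def prepare_data(self):
        # download only; called from a single process, so do not assign state here
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        # called on every process; safe to assign state here
        mnist_full = MNIST(self.data_dir, train=True)
        self.train_set, self.val_set = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=32)
```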
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 9c0c12f981f4b..0065d5947dc26 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -120,9 +120,7 @@ def get_progress_bar_dict(self): def test_v1_7_0_trainer_prepare_data_per_node(tmpdir): - with pytest.deprecated_call( - match="Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0!" - ): + with pytest.deprecated_call(match="Setting `prepare_data_per_node` with the trainer flag is deprecated in v1.5.0"): _ = Trainer(prepare_data_per_node=False) From e51a8ee7a3c8d43aa26fd5c79532a1572b0fac85 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 24 Nov 2021 14:01:55 +0000 Subject: [PATCH 33/59] Fault Tolerant Manual: utilities cleanup (#10703) Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: Carlos Mocholi --- CHANGELOG.md | 1 + pytorch_lightning/utilities/auto_restart.py | 137 +++++++------------- pytorch_lightning/utilities/data.py | 32 +++-- tests/utilities/test_auto_restart.py | 40 ------ 4 files changed, 68 insertions(+), 142 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f7014a71d9a3..740cf5c9c9c11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674)) * Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699)) + * Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703)) - diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 3fa32bc72da5e..074090f10e3fe 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -394,52 +394,6 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str return iter_dataloader -def _dataloader_to_state_dict( - dataloader: DataLoader, iterator: Iterator, num_batches_processed: int = None -) -> List[Dict[str, Any]]: - """Convert a dataloader to its associated state dict.""" - out = {} - if iterator is not None: - out.update(_find_current_worker(iterator)) - - if not isinstance(dataloader.dataset, CaptureIterableDataset): - fast_forward_sampler = _find_fast_forward_samplers(dataloader) - if fast_forward_sampler is not None: - out.update(fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed)) - return out - - -def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: List[Dict[str, Any]]) -> DataLoader: - """Reload ``DataLoader`` fast-forward sampler state dict.""" - fast_forward_sampler = _find_fast_forward_samplers(dataloader) - - if isinstance(fast_forward_sampler, Sampler): - state_dict = {k: v for k, v in state_dict.items() if k not in ("num_workers", "previous_worker")} - fast_forward_sampler.load_state_dict(state_dict) - - return dataloader - - -def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]: - """Find the current DataLoader Iterator worker if multiple workers were used.""" - # get the current number of workers - num_workers = getattr(iterator, "_num_workers", 0) - 
if isinstance(iterator, _MultiProcessingDataLoaderIter): - # fetch next worker - next_worker = (next(iterator._worker_queue_idx_cycle)) % num_workers - # get the current worker from next one - previous_worker = (next_worker - 1) % num_workers - # reset back the `worker_queue_idx` to current one, so we can keep - # going without perturbation. - while next(iterator._worker_queue_idx_cycle) != previous_worker: - pass - else: - previous_worker = None - - # return the captured metadata. - return {"num_workers": num_workers, "previous_worker": previous_worker} - - def _capture_metadata_collate( samples: List, dataset: Dataset, collate_fn: Callable, fault_tolerant_mode: _FaultTolerantMode ) -> Any: @@ -476,6 +430,52 @@ def _capture_metadata_collate( return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata} +# TODO: Merge this code within stateful DataLoaderIter. +def _next_data_wrapper( + fn: Callable, + it: Iterator, + dl: DataLoader, + num_batches_fetched: int, + data_fetcher: "pl.utilities.fetching.AbstractDataFetcher", +) -> Callable: + @wraps(fn) + def wrapper() -> Any: + nonlocal num_batches_fetched + + dataset = dl.dataset + combined_batch = fn() + + batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META] + num_batches_fetched += 1 + + if isinstance(dataset, CaptureIterableDataset): + state = [ + IteratorState( + num_workers=dl.num_workers, + sampler_state=iterator_state, + num_batches_fetched=num_batches_fetched, + worker_id=list(iterator_state.keys())[0], + name=sampler_iter_name, + ) + for sampler_iter_name, iterator_state in state.items() + ] + elif isinstance(dataset, CaptureMapDataset): + ff_sampler = _find_fast_forward_samplers(dl) + state = [ + IteratorState( + num_workers=dl.num_workers, + sampler_state=ff_sampler.state_dict(num_batches_fetched), + dataset_state=state, + worker_id=list(state.keys())[0], + num_batches_fetched=num_batches_fetched, + ) + ] + data_fetcher._store_dataloader_iter_state(it, state) + return batch + + return wrapper + + def patch_dataloader_iterator( dataloader: DataLoader, iterator: Iterator, @@ -506,48 +506,9 @@ def patch_dataloader_iterator( return assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)) - - def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable: - @wraps(fn) - def wrapper(): - nonlocal num_batches_fetched - nonlocal it - nonlocal dl - - dataset = dl.dataset - combined_batch = fn() - - batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META] - num_batches_fetched += 1 - - if isinstance(dataset, CaptureIterableDataset): - state = [ - IteratorState( - num_workers=dataloader.num_workers, - sampler_state=iterator_state, - num_batches_fetched=num_batches_fetched, - worker_id=list(iterator_state.keys())[0], - name=sampler_iter_name, - ) - for sampler_iter_name, iterator_state in state.items() - ] - elif isinstance(dataset, CaptureMapDataset): - ff_sampler = _find_fast_forward_samplers(dl) - state = [ - IteratorState( - num_workers=dataloader.num_workers, - sampler_state=ff_sampler.state_dict(num_batches_fetched), - dataset_state=state, - worker_id=list(state.keys())[0], - num_batches_fetched=num_batches_fetched, - ) - ] - data_fetcher._store_dataloader_iter_state(it, state) - return batch - - return wrapper - - iterator._next_data = _next_data_wrapper(iterator._next_data, iterator, dataloader, num_batches_fetched) + iterator._next_data = _next_data_wrapper( + iterator._next_data, iterator, dataloader, num_batches_fetched, 
data_fetcher + ) def _add_capture_metadata_collate(dataloader: DataLoader) -> None: diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 5b56940460ca4..9f725c37d3f23 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -24,8 +24,8 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler +from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.seed import pl_worker_init_function from pytorch_lightning.utilities.warnings import WarningCache @@ -246,17 +246,8 @@ def _get_dataloader_init_kwargs( dl_kwargs["batch_sampler"] = None dl_kwargs["sampler"] = None - if _fault_tolerant_training(): - dataset = dl_kwargs["dataset"] - if isinstance(dataset, IterableDataset): - # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. - dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"]) - elif get_len(dataset) != float("inf"): - dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"]) - else: - raise MisconfigurationException( - "This shouldn't happen, please open an issue on Lightning Github repository." - ) + if _FaultTolerantMode.detect_current_mode().is_automatic: + dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs) return dl_kwargs @@ -271,6 +262,7 @@ def _dataloader_init_kwargs_resolve_sampler( Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a `FastForwardSampler`. """ + fault_tolerant_mode = _FaultTolerantMode.detect_current_mode() batch_sampler = getattr(dataloader, "batch_sampler") is_predicting = mode == RunningStage.PREDICTING # checking the batch sampler type is different than PyTorch default. @@ -283,7 +275,7 @@ def _dataloader_init_kwargs_resolve_sampler( if is_predicting: batch_sampler = IndexBatchSamplerWrapper(batch_sampler) - if _fault_tolerant_training(): + if fault_tolerant_mode.is_automatic: fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler) fast_forward_sampler.setup(dataloader_batch_size=1) @@ -295,7 +287,7 @@ def _dataloader_init_kwargs_resolve_sampler( "drop_last": False, } - if _fault_tolerant_training(): + if fault_tolerant_mode.is_automatic: fast_forward_sampler = sampler = FastForwardSampler(sampler) fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size) @@ -305,3 +297,15 @@ def _dataloader_init_kwargs_resolve_sampler( def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) + + +def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict: + dataset = dl_kwargs["dataset"] + if isinstance(dataset, IterableDataset): + # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. 
+ dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset) + elif get_len(dataset) != float("inf"): + dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset) + else: + raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.") + return dl_kwargs diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 1c27d582cc6a5..47f5deb344d91 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,8 +39,6 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, - _dataloader_load_state_dict, - _dataloader_to_state_dict, _MultiProcessingDataLoaderIterStateful, _patch_dataloader_get_iterators, _reload_dataloader_state_dict, @@ -665,44 +663,6 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w return dataset -def test_dataloader_to_state_dict_and_reload(): - """ - Note: Those utilities are used only with DataLoader wrapping a ``mapping`` based dataset. - """ - - def create_dataloader(): - dataset = range(50) - batch_size = 8 - sampler = FastForwardSampler(SequentialSampler(dataset)) - sampler.setup(batch_size) - - return DataLoader(dataset, sampler=sampler, batch_size=batch_size) - - dataloader = create_dataloader() - iter_dataloader = iter(dataloader) - _ = next(iter_dataloader) - _ = next(iter_dataloader) - - state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict == { - "num_workers": 0, - "previous_worker": None, - 0: {"current_iteration": 16}, - } - - dataloader = create_dataloader() - dataloader = _dataloader_load_state_dict(dataloader, state_dict) - iter_dataloader = iter(dataloader) - _ = next(iter_dataloader) - - state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict == { - "num_workers": 0, - "previous_worker": None, - 0: {"current_iteration": 24}, - } - - @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"]) def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir): """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled.""" From 30ec4815cb0612689b30ab73988558f1aa1b6f76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 24 Nov 2021 15:58:51 +0100 Subject: [PATCH 34/59] Support re-instantiation for custom DataLoader in Lightning (#10680) Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- CHANGELOG.md | 2 +- pytorch_lightning/lite/lite.py | 14 +++---- pytorch_lightning/lite/wrappers.py | 49 +---------------------- pytorch_lightning/trainer/data_loading.py | 6 ++- pytorch_lightning/utilities/data.py | 49 ++++++++++++++++++++++- tests/lite/test_lite.py | 34 +++++----------- tests/trainer/test_data_loading.py | 29 +++++--------- tests/utilities/test_data.py | 27 +++++++++++++ 8 files changed, 107 insertions(+), 103 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 740cf5c9c9c11..460834dba7a6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699)) * Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703)) -- +- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) - diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 6c41c80a56171..b2adeeac4bd5b 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -25,17 +25,17 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.lite.wrappers import ( - _LiteDataLoader, - _LiteModule, - _LiteOptimizer, - _replace_dataloader_init_method, -) +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.utilities import _StrategyType, DeviceType, move_data_to_device from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors -from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, _update_dataloader, has_iterable_dataset +from pytorch_lightning.utilities.data import ( + _auto_add_worker_init_fn, + _replace_dataloader_init_method, + _update_dataloader, + has_iterable_dataset, +) from pytorch_lightning.utilities.device_parser import _parse_devices from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 3bd94b5f36b74..908ba06bdb84d 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -11,11 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import functools -import inspect -from contextlib import contextmanager -from itertools import chain -from typing import Any, Callable, Generator, Iterator, Optional, Set, Type, Union +from typing import Any, Callable, Generator, Iterator, Optional, Union import torch from torch import nn as nn @@ -110,49 +106,6 @@ def _convert_float_tensor(t: Tensor) -> Tensor: return output -def _wrap_init(init: Callable) -> Callable: - """Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of - :class:`~torch.utils.data.DataLoader`.""" - - @functools.wraps(init) - def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None: - params = dict(inspect.signature(obj.__init__).parameters) - params.pop("args", None) - params.pop("kwargs", None) - for arg_name, arg_value in chain(zip(params, args), kwargs.items()): - setattr(obj, arg_name, arg_value) - init(obj, *args, **kwargs) - - return wrapper - - -# https://stackoverflow.com/a/63851681/9201239 -def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]: - """Returns a list of all classes that inherit directly or indirectly from the given class.""" - subclasses = set() - - def recurse(cl: Type[Any]) -> None: - for subclass in cl.__subclasses__(): - subclasses.add(subclass) - recurse(subclass) - - recurse(cls) - return subclasses - - -@contextmanager -def _replace_dataloader_init_method() -> Generator[None, None, None]: - """This context manager is used to add support for re-instantiation of custom (subclasses) of - :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method.""" - for subclass in _get_all_subclasses(DataLoader): - subclass._old_init = subclass.__init__ - subclass.__init__ = _wrap_init(subclass.__init__) - yield - for subclass in _get_all_subclasses(DataLoader): - subclass.__init__ = subclass._old_init - del subclass._old_init - - class _LiteDataLoader: def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None: """The LiteDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 6044f1320286c..833d0acc4a92e 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -31,6 +31,7 @@ from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, + _replace_dataloader_init_method, _update_dataloader, has_iterable_dataset, has_len_all_ranks, @@ -430,7 +431,10 @@ def request_dataloader( hook = f"{stage.dataloader_prefix}_dataloader" self.call_hook("on_" + hook, pl_module=model) - dataloader = source.dataloader() + with _replace_dataloader_init_method(): + # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as + # attributes on the instance in case the dataloader needs to be re-instantiated later by Ligtning + dataloader = source.dataloader() if isinstance(dataloader, tuple): dataloader = list(dataloader) self.training_type_plugin.barrier("get_dataloaders") diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 9f725c37d3f23..78d75d9972d3c 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import functools import inspect import os +from contextlib import contextmanager from functools import partial -from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Union +from itertools import chain +from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union import torch from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler @@ -299,6 +302,50 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) +def _wrap_init(init: Callable) -> Callable: + """Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of + :class:`~torch.utils.data.DataLoader`.""" + + @functools.wraps(init) + def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None: + params = dict(inspect.signature(obj.__init__).parameters) + params.pop("args", None) + params.pop("kwargs", None) + for arg_name, arg_value in chain(zip(params, args), kwargs.items()): + setattr(obj, arg_name, arg_value) + init(obj, *args, **kwargs) + + return wrapper + + +# https://stackoverflow.com/a/63851681/9201239 +def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]: + """Returns a list of all classes that inherit directly or indirectly from the given class.""" + subclasses = set() + + def recurse(cl: Type[Any]) -> None: + for subclass in cl.__subclasses__(): + subclasses.add(subclass) + recurse(subclass) + + recurse(cls) + return subclasses + + +@contextmanager +def _replace_dataloader_init_method() -> Generator[None, None, None]: + """This context manager is used to add support for re-instantiation of custom (subclasses) of + :class:`~torch.utils.data.DataLoader`. 
It patches the ``__init__`` method.""" + subclasses = _get_all_subclasses(DataLoader) + for subclass in subclasses: + subclass._old_init = subclass.__init__ + subclass.__init__ = _wrap_init(subclass.__init__) + yield + for subclass in subclasses: + subclass.__init__ = subclass._old_init + del subclass._old_init + + def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict: dataset = dl_kwargs["dataset"] if isinstance(dataset, IterableDataset): diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index f9ed4a9da7d9d..663001d08df54 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -164,32 +164,16 @@ def test_setup_dataloaders_return_type(): assert lite_dataloader1.dataset is dataset1 -def test_setup_dataloaders_with_custom_type(): - """Test that Lite intercepts arguments passed to custom subclasses of torch.utils.DataLoader and sets them as - attributes.""" - - class DataLoaderSubclass1(DataLoader): - def __init__(self, attribute1, *args, **kwargs): - # intentionally not setting this attribute, calling super with different args - # self.attribute1 = attribute1 - super().__init__(*args, **kwargs) - - class DataLoaderSubclass2(DataLoaderSubclass1): - def __init__(self, attribute1, attribute2, *args, **kwargs): - # intentionally not setting this attribute, calling super with different args - # self.attribute2 = attribute2 - super().__init__(attribute1, *args, **kwargs) - - class LiteWithCustomDataLoader(LightningLite): +@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method") +def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): + """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method.""" + + class Lite(LightningLite): def run(self): - dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2) - assert dataloader.attribute1 == "attribute1" - assert dataloader.attribute2 == "attribute2" - lite_dataloader = self.setup_dataloaders(dataloader) - assert lite_dataloader.attribute1 == "attribute1" - assert lite_dataloader.attribute2 == "attribute2" - - LiteWithCustomDataLoader().run() + ctx_manager().__enter__.assert_called_once() + + Lite().run() + ctx_manager().__exit__.assert_called_once() def test_setup_dataloaders_raises_for_unknown_custom_args(): diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index 52dfcb76e4f20..139334cbe6f06 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -28,7 +28,7 @@ @RunIf(skip_windows=True) -@pytest.mark.parametrize("mode", (1, 2, 3)) +@pytest.mark.parametrize("mode", (1, 2)) def test_replace_distributed_sampler(tmpdir, mode): class IndexedRandomDataset(RandomDataset): def __getitem__(self, index): @@ -36,11 +36,8 @@ def __getitem__(self, index): class CustomDataLoader(DataLoader): def __init__(self, num_features, dataset, *args, **kwargs): - self.num_features = num_features - super().__init__(dataset, *args, **kwargs) - - class FailureCustomDataLoader(DataLoader): - def __init__(self, num_features, dataset, *args, **kwargs): + # argument `num_features` unused on purpose + # it gets automatically captured by _replace_dataloader_init_method() super().__init__(dataset, *args, **kwargs) class CustomBatchSampler(BatchSampler): @@ -59,11 +56,11 @@ def on_test_start(self) -> None: dataloader = self.trainer.test_dataloaders[0] assert isinstance(dataloader, CustomDataLoader) batch_sampler = dataloader.batch_sampler - if 
self._mode == 2: + if self._mode == 1: assert isinstance(batch_sampler, CustomBatchSampler) # the batch_size is set on the batch sampler assert dataloader.batch_size is None - elif self._mode == 3: + elif self._mode == 2: assert type(batch_sampler) is BatchSampler assert dataloader.batch_size == self._mode assert batch_sampler.batch_size == self._mode @@ -74,15 +71,12 @@ def on_test_start(self) -> None: def create_dataset(self): dataset = IndexedRandomDataset(32, 64) if self._mode == 1: - # this case will raise an error - return FailureCustomDataLoader(32, dataset) - if self._mode == 2: # with a custom batch sampler - batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=True) + batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=1, drop_last=True) return CustomDataLoader(32, dataset, batch_sampler=batch_sampler) - elif self._mode == 3: + elif self._mode == 2: # with no batch sampler provided - return CustomDataLoader(32, dataset, batch_size=3, drop_last=True) + return CustomDataLoader(32, dataset, batch_size=2, drop_last=True) def test_dataloader(self): return [self.create_dataset()] * self._numbers_test_dataloaders @@ -93,12 +87,7 @@ def test_dataloader(self): trainer = Trainer( default_root_dir=tmpdir, limit_test_batches=2, strategy="ddp_find_unused_parameters_false", num_processes=1 ) - if mode == 1: - match = escape("missing attributes are ['num_features']") - with pytest.raises(MisconfigurationException, match=match): - trainer.test(model) - else: - trainer.test(model) + trainer.test(model) class TestSpawnBoringModel(BoringModel): diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index ae1f8c6505efc..839b370dbf1f5 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -4,6 +4,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.utilities.data import ( + _replace_dataloader_init_method, extract_batch_size, get_len, has_iterable_dataset, @@ -112,3 +113,29 @@ def test_has_len_all_rank(): assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model) assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model) + + +def test_replace_dataloader_init_method(): + """Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and + sets them as attributes.""" + + class DataLoaderSubclass1(DataLoader): + def __init__(self, attribute1, *args, **kwargs): + # intentionally not setting this attribute, calling super with different args + # self.attribute1 = attribute1 + super().__init__(*args, **kwargs) + + class DataLoaderSubclass2(DataLoaderSubclass1): + def __init__(self, attribute1, attribute2, *args, **kwargs): + # intentionally not setting this attribute, calling super with different args + # self.attribute2 = attribute2 + super().__init__(attribute1, *args, **kwargs) + + with _replace_dataloader_init_method(): + dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2) + assert dataloader.attribute1 == "attribute1" + + with _replace_dataloader_init_method(): + dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2) + assert dataloader.attribute1 == "attribute1" + assert dataloader.attribute2 == "attribute2" From 0066ff012962ffa920e9bbd750f4edbea05d4505 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 24 Nov 2021 17:36:08 +0000 Subject: [PATCH 35/59] Fault Tolerant Manual: Enable the feature (#10707) --- 
CHANGELOG.md | 1 + pytorch_lightning/utilities/auto_restart.py | 20 +-- pytorch_lightning/utilities/fetching.py | 12 +- tests/utilities/test_auto_restart.py | 161 +++++++++++++++++++- 4 files changed, 178 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 460834dba7a6f..59d29e18363c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699)) * Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703)) + * Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707)) - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 074090f10e3fe..9f99634bd15b2 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -251,9 +251,7 @@ def __len__(self) -> int: def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None: # as workers aren't available, the ``state_dict``` is cached until workers are made available. - state_dict = deepcopy(state_dict) - state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers) - self._cached_state_dict = state_dict + self._cached_state_dict = _rotate_worker_indices(deepcopy(state_dict), latest_worker_id, num_workers) def state_dict(self) -> Dict[int, Dict[str, Any]]: return {self.worker_id: {"rng_states": collect_rng_states()}} @@ -513,14 +511,17 @@ def patch_dataloader_iterator( def _add_capture_metadata_collate(dataloader: DataLoader) -> None: """Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled.""" - faut_tolerant_mode = _FaultTolerantMode.detect_current_mode() - if not faut_tolerant_mode.is_enabled: + fault_tolerant_mode = _FaultTolerantMode.detect_current_mode() + collate_fn = dataloader.collate_fn + if not fault_tolerant_mode.is_enabled or ( + isinstance(collate_fn, partial) and collate_fn.func is _capture_metadata_collate + ): return dataloader.collate_fn = partial( _capture_metadata_collate, dataset=dataloader.dataset, - collate_fn=dataloader.collate_fn, - fault_tolerant_mode=faut_tolerant_mode, + collate_fn=collate_fn, + fault_tolerant_mode=fault_tolerant_mode, ) @@ -658,8 +659,7 @@ def _next_index(self) -> Any: return indexes def _prepare_loader(self, loader): - if not isinstance(loader.collate_fn, partial): - loader.collate_fn = partial(_capture_metadata_collate, dataset=loader.dataset, collate_fn=loader.collate_fn) + _add_capture_metadata_collate(loader) self._loader = loader self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher self.num_batches_fetched = 0 @@ -723,6 +723,8 @@ def _get_iterator(self) -> "_BaseDataLoaderIter": def _patch_dataloader_get_iterators() -> None: """This function is used to replace the DataLoader iterator by 
their stateful version.""" + if not _FaultTolerantMode.detect_current_mode().is_manual: + return if not hasattr(DataLoader, "_ori_get_iterator"): DataLoader._ori_get_iterator = DataLoader._get_iterator DataLoader._get_iterator = _get_iterator diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index f5bb4be032d10..7ac0bfa00c1e7 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -16,7 +16,6 @@ from collections.abc import Iterable, Iterator from contextlib import contextmanager from copy import deepcopy -from functools import partial from typing import Any, Callable, Generator, List, Optional, Tuple import torch @@ -27,6 +26,8 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, + _patch_dataloader_get_iterators, + _teardown_dataloader_get_iterators, IteratorState, MergedIteratorState, patch_dataloader_iterator, @@ -109,11 +110,7 @@ def _add_capture_metadata_collate(dataloader: Iterable) -> None: if isinstance(dataloader, CombinedLoader): dataloader = dataloader.loaders - def add_capture_metadata_collate(dataloader: DataLoader): - if not isinstance(dataloader.collate_fn, partial): - _add_capture_metadata_collate(dataloader) - - apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate) + apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate) def append_batch(self, batch) -> None: self.batches.append(batch) @@ -206,6 +203,8 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: if self.dataloader is None: raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.") self.reset() + self._attach_data_fetcher() + _patch_dataloader_get_iterators() self.dataloader_iter = iter(self.dataloader) self._apply_patch() self.prefetching(self.prefetch_batches) @@ -226,6 +225,7 @@ def teardown(self) -> None: if isinstance(self.dataloader, DataLoader): CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader) self.dataloader_iter = None + _teardown_dataloader_get_iterators() class DataFetcher(AbstractDataFetcher): diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 47f5deb344d91..c69b70b65b13c 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -20,7 +20,7 @@ from contextlib import suppress from copy import deepcopy from dataclasses import asdict -from typing import List, Optional +from typing import Iterator, List, Optional from unittest import mock from unittest.mock import ANY @@ -1317,3 +1317,162 @@ def test_stateful_workers(num_workers): _reload_dataloader_state_dict(dataloader, asdict(reloaded_state)) assert dataloader.sampler.counter == dataloader.dataset.counter == 1 data_fetcher.teardown() + + +class RandomFaultTolerantDataset(RandomGetItemDataset): + def __init__(self, *args, seed: int, **kwargs): + super().__init__(*args, **kwargs) + self.seed = seed + self._cache_state_dict = None + self.generator = None + self.counter_debug = 0 + + @property + def worker_id(self): + info = get_worker_info() + return info.id if info else 0 + + def __getitem__(self, index): + if self._cache_state_dict: + state_dict = self._cache_state_dict[self.worker_id] + self.generator = random.Random() + self.generator.setstate(state_dict["random_state"]) + self._cache_state_dict = None + + if not self.generator: + 
self.generator = random.Random(self.seed + self.worker_id) + return torch.tensor(index + self.generator.random()) + + def state_dict(self): + return {self.worker_id: {"random_state": self.generator.getstate()}} + + def load_state_dict(self, state_dict): + self._cache_state_dict = state_dict + + +class RandomFaultTolerantSampler(RandomSampler): + def __init__(self, *args, seed: int = 0, generator=None, **kwargs): + generator = torch.Generator().manual_seed(seed) + super().__init__(*args, generator=generator, **kwargs) + self.counter = 0 + self.restarting = False + + def state_dict(self): + return {"random_state": self.state, "counter": self.counter} + + def load_state_dict(self, state_dict): + self.generator.set_state(state_dict.get("random_state")) + self.counter = state_dict["counter"] + self.restarting = True + + def __len__(self): + return len(self.data_source) - self.counter + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + + self.state = self.generator.get_state() + indices = torch.randperm(n, generator=self.generator).tolist() + + if not self.restarting: + self.counter = 0 + else: + indices = indices[self.counter :] + self.restarting = False + + for index in indices: + self.counter += 1 + yield index + + self.counter = 0 + + +@pytest.mark.parametrize( + ["train_dataset_cls", "val_dataset_cls"], + [ + ([RandomFaultTolerantDataset, RandomFaultTolerantDataset], [RandomFaultTolerantDataset]), + ], +) +@pytest.mark.parametrize("val_check_interval", [0.5]) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) +def test_fault_tolerant_manual_mode(val_check_interval, train_dataset_cls, val_dataset_cls, tmpdir): + class TestModel(BoringModel): + def __init__(self, should_fail: bool = False): + super().__init__() + self.layer = torch.nn.Linear(1, 2) + self.should_fail = should_fail + self.batches = [] + + def training_step(self, batch, batch_idx): + if self.should_fail and batch_idx == 7: + raise CustomException + self.batches.append(batch) + losses = [] + for b in batch: + losses.append(super().training_step(b, batch_idx)["loss"]) + return torch.stack(losses).mean() + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + pass + + validation_epoch_end = None + + def _create_dataloader_kwargs(self, dataset_class, dataset_len, seed, num_workers): + dl_kwargs = {} + dl_kwargs["dataset"] = dataset_class(dataset_len, 1, seed=seed) + dl_kwargs["sampler"] = RandomFaultTolerantSampler(dl_kwargs["dataset"], seed=seed) + dl_kwargs["num_workers"] = num_workers + dl_kwargs["batch_size"] = 1 + return dl_kwargs + + def train_dataloader(self): + return [ + DataLoader( + **self._create_dataloader_kwargs( + dataset_class, 10, seed, seed + 1 if val_check_interval == 1.0 else 0 + ) + ) + for seed, dataset_class in enumerate(train_dataset_cls) + ] + + def val_dataloader(self): + return [ + DataLoader(**self._create_dataloader_kwargs(dataset_class, 1, seed, 0)) + for seed, dataset_class in enumerate(val_dataset_cls) + ] + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + seed_everything(42) + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval) + trainer.fit(model) + total_batches = model.batches + total_weight = deepcopy(model.layer.weight) + trainer.train_dataloader = None + + seed_everything(42) + model = TestModel(should_fail=True) + 
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval) + with suppress(CustomException): + trainer.fit(model) + trainer.train_dataloader = None + failed_batches = model.batches + failed_weight = deepcopy(model.layer.weight) + + checkpoint_path = str(tmpdir / ".pl_auto_save.ckpt") + assert os.path.exists(checkpoint_path) + + seed_everything(42) + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval) + trainer.fit(model, ckpt_path=checkpoint_path) + trainer.train_dataloader = None + restart_batches = model.batches + + torch.testing.assert_allclose(total_batches, failed_batches + restart_batches) + assert not torch.equal(total_weight, failed_weight) + assert torch.equal(total_weight, model.layer.weight) From f8b2d5b128096a32e3b00b0b5fda40dd61babcd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 24 Nov 2021 22:51:11 +0100 Subject: [PATCH 36/59] Improve error message on `TypeError` during `DataLoader` reconstruction (#10719) --- CHANGELOG.md | 2 +- pytorch_lightning/utilities/data.py | 20 ++++++++++++++++- tests/utilities/test_data.py | 33 +++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 59d29e18363c7..a87b3b95f0659 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) -- +- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) - diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 78d75d9972d3c..9963bf6c85ffd 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -180,7 +180,25 @@ def get_len(dataloader: DataLoader) -> Union[int, float]: def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader: dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode) dl_cls = type(dataloader) - dataloader = dl_cls(**dl_kwargs) + try: + dataloader = dl_cls(**dl_kwargs) + except TypeError as e: + # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass + # `__init__` arguments map to one `DataLoader.__init__` argument + import re + + match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) + if not match: + # an unexpected `TypeError`, continue failure + raise + argument = match.groups()[0] + message = ( + f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument" + f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" + f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." + f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." 
+ ) + raise MisconfigurationException(message) from e return dataloader diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 839b370dbf1f5..e202941cf0fbb 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -5,6 +5,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.utilities.data import ( _replace_dataloader_init_method, + _update_dataloader, extract_batch_size, get_len, has_iterable_dataset, @@ -115,6 +116,38 @@ def test_has_len_all_rank(): assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model) +def test_update_dataloader_typerror_custom_exception(): + class BadImpl(DataLoader): + def __init__(self, foo, *args, **kwargs): + self.foo = foo + # positional conflict with `dataset` + super().__init__(foo, *args, **kwargs) + + dataloader = BadImpl([1, 2, 3]) + with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): + _update_dataloader(dataloader, dataloader.sampler) + + class BadImpl2(DataLoader): + def __init__(self, randomize, *args, **kwargs): + self.randomize = randomize + # keyword conflict with `shuffle` + super().__init__(*args, shuffle=randomize, **kwargs) + + dataloader = BadImpl2(False, []) + with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"): + _update_dataloader(dataloader, dataloader.sampler) + + class GoodImpl(DataLoader): + def __init__(self, randomize, *args, **kwargs): + # fixed implementation, kwargs are filtered + self.randomize = randomize or kwargs.pop("shuffle", False) + super().__init__(*args, shuffle=randomize, **kwargs) + + dataloader = GoodImpl(False, []) + new_dataloader = _update_dataloader(dataloader, dataloader.sampler) + assert isinstance(new_dataloader, GoodImpl) + + def test_replace_dataloader_init_method(): """Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and sets them as attributes.""" From 85d7c4dce4661ef5c12fcbd0d9e164ffc371e1af Mon Sep 17 00:00:00 2001 From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Date: Wed, 24 Nov 2021 19:19:30 -0500 Subject: [PATCH 37/59] Configure mypy to install dependencies in CI and update pyproject.toml (#10682) * mypy install deps * fix deps * add examples * fix type errors * fix type error * fix * fix * update pyproject.toml --- .github/workflows/code-checks.yml | 4 +-- pyproject.toml | 28 ++++++------------- .../connectors/logger_connector/result.py | 4 +-- .../utilities/parameter_tying.py | 3 +- 4 files changed, 13 insertions(+), 26 deletions(-) diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml index e99863dc794d4..8cd4206ab61b7 100644 --- a/.github/workflows/code-checks.yml +++ b/.github/workflows/code-checks.yml @@ -14,8 +14,8 @@ jobs: - uses: actions/setup-python@v2 with: python-version: 3.9 - - name: Install mypy + - name: Install dependencies run: | - grep mypy requirements/test.txt | xargs -0 pip install + pip install '.[dev]' pip list - run: mypy --install-types --non-interactive diff --git a/pyproject.toml b/pyproject.toml index c527ffaa856cf..e3c373aee5aeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,41 +36,31 @@ disable_error_code = "attr-defined" # style choices warn_no_return = "False" -# TODO: Fix typing for these modules +# Changes mypy default to ignore all errors [[tool.mypy.overrides]] module = [ - "pytorch_lightning.callbacks.*", - "pytorch_lightning.core.*", - 
"pytorch_lightning.loggers.*", - "pytorch_lightning.loops.*", - "pytorch_lightning.overrides.*", - "pytorch_lightning.plugins.environments.*", - "pytorch_lightning.plugins.training_type.*", - "pytorch_lightning.profiler.*", - "pytorch_lightning.trainer.*", - "pytorch_lightning.distributed.*", - "pytorch_lightning.tuner.*", - "pytorch_lightning.utilities.*", + "pytorch_lightning.*", ] ignore_errors = "True" +# Override the default for files where we would like to enable type checking +# TODO: Bring more files into this section [[tool.mypy.overrides]] module = [ "pytorch_lightning.callbacks.device_stats_monitor", "pytorch_lightning.callbacks.early_stopping", "pytorch_lightning.callbacks.gpu_stats_monitor", "pytorch_lightning.callbacks.gradient_accumulation_scheduler", - "pytorch_lightning.callbacks.lr_monitor", "pytorch_lightning.callbacks.model_summary", "pytorch_lightning.callbacks.progress", "pytorch_lightning.callbacks.pruning", "pytorch_lightning.callbacks.rich_model_summary", "pytorch_lightning.core.optimizer", - "pytorch_lightning.lite.*", - "pytorch_lightning.loops.optimization.*", + "pytorch_lightning.loops.optimization.closure.py", + "pytorch_lightning.loops.optimization.manual_loop.py", "pytorch_lightning.loops.evaluation_loop", - "pytorch_lightning.trainer.connectors.checkpoint_connector", - "pytorch_lightning.trainer.connectors.logger_connector.*", + "pytorch_lightning.trainer.connectors.logger_connector.py", + "pytorch_lightning.trainer.connectors.logger_connector.fx_validator.py", "pytorch_lightning.trainer.connectors.signal_connector", "pytorch_lightning.trainer.progress.*", "pytorch_lightning.tuner.auto_gpu_select", @@ -80,8 +70,6 @@ module = [ "pytorch_lightning.utilities.cloud_io", "pytorch_lightning.utilities.device_dtype_mixin", "pytorch_lightning.utilities.device_parser", - "pytorch_lightning.utilities.distributed", - "pytorch_lightning.utilities.memory", "pytorch_lightning.utilities.model_summary", "pytorch_lightning.utilities.parameter_tying", "pytorch_lightning.utilities.parsing", diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ab3c0f1804c2a..e10360a5fb564 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -280,8 +280,8 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]: ) # return cached value - if self._computed is not None: # type: ignore - return self._computed # type: ignore + if self._computed is not None: + return self._computed self._computed = compute(*args, **kwargs) return self._computed diff --git a/pytorch_lightning/utilities/parameter_tying.py b/pytorch_lightning/utilities/parameter_tying.py index 7a074deec9d1d..8278c6510cf4a 100644 --- a/pytorch_lightning/utilities/parameter_tying.py +++ b/pytorch_lightning/utilities/parameter_tying.py @@ -19,7 +19,6 @@ from typing import Dict, List, Optional from torch import nn -from torch.nn import Parameter def find_shared_parameters(module: nn.Module) -> List[str]: @@ -64,7 +63,7 @@ def _get_module_by_path(module: nn.Module, path: str) -> nn.Module: return module -def _set_module_by_path(module: nn.Module, path: str, value: Parameter) -> None: +def _set_module_by_path(module: nn.Module, path: str, value: nn.Module) -> None: path = path.split(".") for name in path[:-1]: module = getattr(module, name) From b57feccbff57271144dae70a82b1948f7a4cf7af Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 25 Nov 2021 11:27:13 +0100 Subject: [PATCH 38/59] Be explicit with mypy ignores (#10751) * Ignore mypy only for failing files * Comment --- pyproject.toml | 138 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 102 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3c373aee5aeb..c266e0684e974 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,43 +37,109 @@ disable_error_code = "attr-defined" warn_no_return = "False" # Changes mypy default to ignore all errors +# TODO: the goal is for this to be empty [[tool.mypy.overrides]] +# the list can be generated with: +# mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",' module = [ - "pytorch_lightning.*", + "pytorch_lightning.accelerators.accelerator", + "pytorch_lightning.accelerators.gpu", + "pytorch_lightning.callbacks.finetuning", + "pytorch_lightning.callbacks.lr_monitor", + "pytorch_lightning.callbacks.model_checkpoint", + "pytorch_lightning.callbacks.prediction_writer", + "pytorch_lightning.callbacks.progress.base", + "pytorch_lightning.callbacks.progress.progress", + "pytorch_lightning.callbacks.progress.rich_progress", + "pytorch_lightning.callbacks.progress.tqdm_progress", + "pytorch_lightning.callbacks.quantization", + "pytorch_lightning.callbacks.stochastic_weight_avg", + "pytorch_lightning.callbacks.timer", + "pytorch_lightning.callbacks.xla_stats_monitor", + "pytorch_lightning.core.datamodule", + "pytorch_lightning.core.decorators", + "pytorch_lightning.core.lightning", + "pytorch_lightning.core.mixins.device_dtype_mixin", + "pytorch_lightning.core.mixins.hparams_mixin", + "pytorch_lightning.core.saving", + "pytorch_lightning.distributed.dist", + "pytorch_lightning.lite.lite", + "pytorch_lightning.lite.wrappers", + "pytorch_lightning.loggers.base", + "pytorch_lightning.loggers.comet", + "pytorch_lightning.loggers.csv_logs", + "pytorch_lightning.loggers.mlflow", + "pytorch_lightning.loggers.neptune", + "pytorch_lightning.loggers.tensorboard", + "pytorch_lightning.loggers.test_tube", + "pytorch_lightning.loggers.wandb", + "pytorch_lightning.loops.base", + "pytorch_lightning.loops.batch.training_batch_loop", + "pytorch_lightning.loops.dataloader.dataloader_loop", + "pytorch_lightning.loops.dataloader.evaluation_loop", + "pytorch_lightning.loops.dataloader.prediction_loop", + "pytorch_lightning.loops.epoch.evaluation_epoch_loop", + "pytorch_lightning.loops.epoch.prediction_epoch_loop", + "pytorch_lightning.loops.epoch.training_epoch_loop", + "pytorch_lightning.loops.fit_loop", + "pytorch_lightning.loops.optimization.optimizer_loop", + "pytorch_lightning.loops.utilities", + "pytorch_lightning.overrides.base", + "pytorch_lightning.overrides.data_parallel", + "pytorch_lightning.overrides.distributed", + "pytorch_lightning.overrides.fairscale", + "pytorch_lightning.plugins.environments.lightning_environment", + "pytorch_lightning.plugins.environments.lsf_environment", + "pytorch_lightning.plugins.environments.slurm_environment", + "pytorch_lightning.plugins.environments.torchelastic_environment", + "pytorch_lightning.plugins.precision.deepspeed", + "pytorch_lightning.plugins.precision.native_amp", + "pytorch_lightning.plugins.precision.precision_plugin", + "pytorch_lightning.plugins.training_type.ddp", + "pytorch_lightning.plugins.training_type.ddp2", + "pytorch_lightning.plugins.training_type.ddp_spawn", + "pytorch_lightning.plugins.training_type.deepspeed", + 
"pytorch_lightning.plugins.training_type.dp", + "pytorch_lightning.plugins.training_type.fully_sharded", + "pytorch_lightning.plugins.training_type.horovod", + "pytorch_lightning.plugins.training_type.ipu", + "pytorch_lightning.plugins.training_type.parallel", + "pytorch_lightning.plugins.training_type.sharded", + "pytorch_lightning.plugins.training_type.sharded_spawn", + "pytorch_lightning.plugins.training_type.single_device", + "pytorch_lightning.plugins.training_type.single_tpu", + "pytorch_lightning.plugins.training_type.tpu_spawn", + "pytorch_lightning.plugins.training_type.training_type_plugin", + "pytorch_lightning.profiler.advanced", + "pytorch_lightning.profiler.base", + "pytorch_lightning.profiler.pytorch", + "pytorch_lightning.profiler.simple", + "pytorch_lightning.trainer.callback_hook", + "pytorch_lightning.trainer.configuration_validator", + "pytorch_lightning.trainer.connectors.accelerator_connector", + "pytorch_lightning.trainer.connectors.callback_connector", + "pytorch_lightning.trainer.connectors.checkpoint_connector", + "pytorch_lightning.trainer.connectors.data_connector", + "pytorch_lightning.trainer.connectors.logger_connector.result", + "pytorch_lightning.trainer.data_loading", + "pytorch_lightning.trainer.optimizers", + "pytorch_lightning.trainer.supporters", + "pytorch_lightning.trainer.trainer", + "pytorch_lightning.tuner.batch_size_scaling", + "pytorch_lightning.tuner.lr_finder", + "pytorch_lightning.tuner.tuning", + "pytorch_lightning.utilities.auto_restart", + "pytorch_lightning.utilities.data", + "pytorch_lightning.utilities.deepspeed", + "pytorch_lightning.utilities.distributed", + "pytorch_lightning.utilities.enums", + "pytorch_lightning.utilities.fetching", + "pytorch_lightning.utilities.imports", + "pytorch_lightning.utilities.memory", + "pytorch_lightning.utilities.meta", + "pytorch_lightning.utilities.metrics", + "pytorch_lightning.utilities.migration", + "pytorch_lightning.utilities.upgrade_checkpoint", + "pytorch_lightning.utilities.warnings", ] ignore_errors = "True" - -# Override the default for files where we would like to enable type checking -# TODO: Bring more files into this section -[[tool.mypy.overrides]] -module = [ - "pytorch_lightning.callbacks.device_stats_monitor", - "pytorch_lightning.callbacks.early_stopping", - "pytorch_lightning.callbacks.gpu_stats_monitor", - "pytorch_lightning.callbacks.gradient_accumulation_scheduler", - "pytorch_lightning.callbacks.model_summary", - "pytorch_lightning.callbacks.progress", - "pytorch_lightning.callbacks.pruning", - "pytorch_lightning.callbacks.rich_model_summary", - "pytorch_lightning.core.optimizer", - "pytorch_lightning.loops.optimization.closure.py", - "pytorch_lightning.loops.optimization.manual_loop.py", - "pytorch_lightning.loops.evaluation_loop", - "pytorch_lightning.trainer.connectors.logger_connector.py", - "pytorch_lightning.trainer.connectors.logger_connector.fx_validator.py", - "pytorch_lightning.trainer.connectors.signal_connector", - "pytorch_lightning.trainer.progress.*", - "pytorch_lightning.tuner.auto_gpu_select", - "pytorch_lightning.utilities.apply_func", - "pytorch_lightning.utilities.argparse", - "pytorch_lightning.utilities.cli", - "pytorch_lightning.utilities.cloud_io", - "pytorch_lightning.utilities.device_dtype_mixin", - "pytorch_lightning.utilities.device_parser", - "pytorch_lightning.utilities.model_summary", - "pytorch_lightning.utilities.parameter_tying", - "pytorch_lightning.utilities.parsing", - "pytorch_lightning.utilities.seed", - 
"pytorch_lightning.utilities.xla_device", -] -ignore_errors = "False" From e0b4bb2ea34ea2ade517da2b9a4cdbb7d97e3de0 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 25 Nov 2021 21:11:03 +0530 Subject: [PATCH 39/59] Deprecate `DeviceType` in favor of `_AcceleratorType` (#10503) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- CHANGELOG.md | 3 + .../callbacks/gpu_stats_monitor.py | 4 +- .../callbacks/xla_stats_monitor.py | 4 +- pytorch_lightning/lite/lite.py | 10 +- .../loops/optimization/optimizer_loop.py | 4 +- .../connectors/accelerator_connector.py | 98 +++++++++---------- .../logger_connector/logger_connector.py | 4 +- pytorch_lightning/trainer/trainer.py | 20 ++-- pytorch_lightning/utilities/__init__.py | 2 +- pytorch_lightning/utilities/enums.py | 40 ++++++-- pytorch_lightning/utilities/model_summary.py | 8 +- .../test_accelerator_connector.py | 4 +- tests/accelerators/test_ipu.py | 4 +- tests/deprecated_api/test_remove_1-8.py | 8 +- tests/models/test_tpu.py | 4 +- tests/trainer/test_trainer.py | 98 ++++++++++--------- tests/utilities/test_enums.py | 14 +-- 17 files changed, 185 insertions(+), 144 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a87b3b95f0659..43c0c6ab144f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) +- Deprecated `DeviceType` in favor of `_AcceleratorType` ([#10503](https://github.com/PyTorchLightning/pytorch-lightning/pull/10503)) + + - Deprecated the property `Trainer.slurm_job_id` in favor of the new `SLURMEnvironment.job_id()` method ([#10622](https://github.com/PyTorchLightning/pytorch-lightning/pull/10622)) diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 7ee6771056666..088c8e650074c 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -29,7 +29,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import DeviceType, rank_zero_deprecation, rank_zero_only +from pytorch_lightning.utilities import _AcceleratorType, rank_zero_deprecation, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -126,7 +126,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O if not trainer.logger: raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.") - if trainer._device_type != DeviceType.GPU: + if trainer._device_type != _AcceleratorType.GPU: raise MisconfigurationException( "You are using GPUStatsMonitor but are not running on GPU" f" since gpus attribute in Trainer is set to {trainer.gpus}." 
diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py index 20d3f1b8ba925..9c4f09c08a9b3 100644 --- a/pytorch_lightning/callbacks/xla_stats_monitor.py +++ b/pytorch_lightning/callbacks/xla_stats_monitor.py @@ -21,7 +21,7 @@ import time from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import _TPU_AVAILABLE, DeviceType, rank_zero_deprecation, rank_zero_info +from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: @@ -70,7 +70,7 @@ def on_train_start(self, trainer, pl_module) -> None: if not trainer.logger: raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.") - if trainer._device_type != DeviceType.TPU: + if trainer._device_type != _AcceleratorType.TPU: raise MisconfigurationException( "You are using XLAStatsMonitor but are not running on TPU" f" since `tpu_cores` attribute in Trainer is set to {trainer.tpu_cores}." diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index b2adeeac4bd5b..9073f5dd54903 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -28,7 +28,7 @@ from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector -from pytorch_lightning.utilities import _StrategyType, DeviceType, move_data_to_device +from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, @@ -448,11 +448,11 @@ def _check_strategy_support(self, strategy: Optional[Union[str, TrainingTypePlug ) @staticmethod - def _supported_device_types() -> Sequence[DeviceType]: + def _supported_device_types() -> Sequence[_AcceleratorType]: return ( - DeviceType.CPU, - DeviceType.GPU, - DeviceType.TPU, + _AcceleratorType.CPU, + _AcceleratorType.GPU, + _AcceleratorType.TPU, ) @staticmethod diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index 7050ac75de8eb..b6bc1c3c25bf9 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -30,7 +30,7 @@ ) from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler from pytorch_lightning.trainer.progress import OptimizationProgress -from pytorch_lightning.utilities import AMPType, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.imports import _TPU_AVAILABLE @@ -378,7 +378,7 @@ def _optimizer_step( optimizer, opt_idx, train_step_and_backward_closure, - on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE), + on_tpu=(self.trainer._device_type == _AcceleratorType.TPU and _TPU_AVAILABLE), using_native_amp=(self.trainer.amp_backend is not None and self.trainer.amp_backend == AMPType.NATIVE), using_lbfgs=is_lbfgs, ) diff 
--git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index c95d46e77b977..ba1166a019e6b 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -61,10 +61,10 @@ TorchElasticEnvironment, ) from pytorch_lightning.utilities import ( + _AcceleratorType, _StrategyType, AMPType, device_parser, - DeviceType, rank_zero_deprecation, rank_zero_info, rank_zero_warn, @@ -106,7 +106,7 @@ def __init__( plugins, ): # initialization - self._device_type = DeviceType.CPU + self._device_type = _AcceleratorType.CPU self._distrib_type = None self._accelerator_type = None @@ -199,32 +199,32 @@ def _init_deterministic(self, deterministic: bool) -> None: def select_accelerator_type(self) -> None: if self.distributed_backend == "auto": if self.has_tpu: - self._accelerator_type = DeviceType.TPU + self._accelerator_type = _AcceleratorType.TPU elif self.has_ipu: - self._accelerator_type = DeviceType.IPU + self._accelerator_type = _AcceleratorType.IPU elif self.has_gpu: - self._accelerator_type = DeviceType.GPU + self._accelerator_type = _AcceleratorType.GPU else: self._set_devices_to_cpu_num_processes() - self._accelerator_type = DeviceType.CPU - elif self.distributed_backend == DeviceType.TPU: + self._accelerator_type = _AcceleratorType.CPU + elif self.distributed_backend == _AcceleratorType.TPU: if not self.has_tpu: msg = "TPUs are not available" if not _TPU_AVAILABLE else "you didn't pass `tpu_cores` to `Trainer`" raise MisconfigurationException(f"You passed `accelerator='tpu'`, but {msg}.") - self._accelerator_type = DeviceType.TPU - elif self.distributed_backend == DeviceType.IPU: + self._accelerator_type = _AcceleratorType.TPU + elif self.distributed_backend == _AcceleratorType.IPU: if not self.has_ipu: msg = "IPUs are not available" if not _IPU_AVAILABLE else "you didn't pass `ipus` to `Trainer`" raise MisconfigurationException(f"You passed `accelerator='ipu'`, but {msg}.") - self._accelerator_type = DeviceType.IPU - elif self.distributed_backend == DeviceType.GPU: + self._accelerator_type = _AcceleratorType.IPU + elif self.distributed_backend == _AcceleratorType.GPU: if not self.has_gpu: msg = "you didn't pass `gpus` to `Trainer`" if torch.cuda.is_available() else "GPUs are not available" raise MisconfigurationException(f"You passed `accelerator='gpu'`, but {msg}.") - self._accelerator_type = DeviceType.GPU - elif self.distributed_backend == DeviceType.CPU: + self._accelerator_type = _AcceleratorType.GPU + elif self.distributed_backend == _AcceleratorType.CPU: self._set_devices_to_cpu_num_processes() - self._accelerator_type = DeviceType.CPU + self._accelerator_type = _AcceleratorType.CPU if self.distributed_backend in self.accelerator_types: self.distributed_backend = None @@ -250,29 +250,29 @@ def _warn_if_devices_flag_ignored(self) -> None: if self.devices is None: return devices_warning = f"The flag `devices={self.devices}` will be ignored, as you have set" - if self.distributed_backend in ("auto", DeviceType.TPU): + if self.distributed_backend in ("auto", _AcceleratorType.TPU): if self.tpu_cores is not None: rank_zero_warn(f"{devices_warning} `tpu_cores={self.tpu_cores}`") - elif self.distributed_backend in ("auto", DeviceType.IPU): + elif self.distributed_backend in ("auto", _AcceleratorType.IPU): if self.ipus is not None: rank_zero_warn(f"{devices_warning} `ipus={self.ipus}`") - elif self.distributed_backend in ("auto", 
DeviceType.GPU): + elif self.distributed_backend in ("auto", _AcceleratorType.GPU): if self.gpus is not None: rank_zero_warn(f"{devices_warning} `gpus={self.gpus}`") - elif self.distributed_backend in ("auto", DeviceType.CPU): + elif self.distributed_backend in ("auto", _AcceleratorType.CPU): if self.num_processes != 1: rank_zero_warn(f"{devices_warning} `num_processes={self.num_processes}`") def _set_devices_if_none(self) -> None: if self.devices is not None: return - if self._accelerator_type == DeviceType.TPU: + if self._accelerator_type == _AcceleratorType.TPU: self.devices = self.tpu_cores - elif self._accelerator_type == DeviceType.IPU: + elif self._accelerator_type == _AcceleratorType.IPU: self.devices = self.ipus - elif self._accelerator_type == DeviceType.GPU: + elif self._accelerator_type == _AcceleratorType.GPU: self.devices = self.gpus - elif self._accelerator_type == DeviceType.CPU: + elif self._accelerator_type == _AcceleratorType.CPU: self.devices = self.num_processes def _handle_accelerator_and_strategy(self) -> None: @@ -386,7 +386,7 @@ def handle_given_plugins(self) -> None: @property def accelerator_types(self) -> List[str]: - return ["auto"] + list(DeviceType) + return ["auto"] + list(_AcceleratorType) @property def precision_plugin(self) -> PrecisionPlugin: @@ -424,7 +424,7 @@ def has_cpu(self) -> bool: @property def use_cpu(self) -> bool: - return self._accelerator_type == DeviceType.CPU + return self._accelerator_type == _AcceleratorType.CPU @property def has_gpu(self) -> bool: @@ -433,11 +433,11 @@ def has_gpu(self) -> bool: gpus = self.parallel_device_ids if gpus is not None and len(gpus) > 0: return True - return self._map_devices_to_accelerator(DeviceType.GPU) + return self._map_devices_to_accelerator(_AcceleratorType.GPU) @property def use_gpu(self) -> bool: - return self._accelerator_type == DeviceType.GPU and self.has_gpu + return self._accelerator_type == _AcceleratorType.GPU and self.has_gpu @property def has_tpu(self) -> bool: @@ -445,11 +445,11 @@ def has_tpu(self) -> bool: # `tpu_cores` to Trainer for training. if self.tpu_cores is not None: return True - return self._map_devices_to_accelerator(DeviceType.TPU) + return self._map_devices_to_accelerator(_AcceleratorType.TPU) @property def use_tpu(self) -> bool: - return self._accelerator_type == DeviceType.TPU and self.has_tpu + return self._accelerator_type == _AcceleratorType.TPU and self.has_tpu @property def tpu_id(self) -> Optional[int]: @@ -463,36 +463,36 @@ def has_ipu(self) -> bool: # `ipus` to Trainer for training. 
if self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin): return True - return self._map_devices_to_accelerator(DeviceType.IPU) + return self._map_devices_to_accelerator(_AcceleratorType.IPU) @property def use_ipu(self) -> bool: - return self._accelerator_type == DeviceType.IPU and self.has_ipu + return self._accelerator_type == _AcceleratorType.IPU and self.has_ipu def _set_devices_to_cpu_num_processes(self) -> None: if self.num_processes == 1: - self._map_devices_to_accelerator(DeviceType.CPU) + self._map_devices_to_accelerator(_AcceleratorType.CPU) def _map_devices_to_accelerator(self, accelerator: str) -> bool: if self.devices is None: return False - if accelerator == DeviceType.TPU and _TPU_AVAILABLE: + if accelerator == _AcceleratorType.TPU and _TPU_AVAILABLE: if self.devices == "auto": self.devices = TPUAccelerator.auto_device_count() self.tpu_cores = device_parser.parse_tpu_cores(self.devices) return True - if accelerator == DeviceType.IPU and _IPU_AVAILABLE: + if accelerator == _AcceleratorType.IPU and _IPU_AVAILABLE: if self.devices == "auto": self.devices = IPUAccelerator.auto_device_count() self.ipus = self.devices return True - if accelerator == DeviceType.GPU and torch.cuda.is_available(): + if accelerator == _AcceleratorType.GPU and torch.cuda.is_available(): if self.devices == "auto": self.devices = GPUAccelerator.auto_device_count() self.gpus = self.devices self.parallel_device_ids = device_parser.parse_gpu_ids(self.devices) return True - if accelerator == DeviceType.CPU: + if accelerator == _AcceleratorType.CPU: if self.devices == "auto": self.devices = CPUAccelerator.auto_device_count() if not isinstance(self.devices, int): @@ -829,7 +829,7 @@ def set_distributed_mode(self, strategy: Optional[str] = None): if isinstance(self.distributed_backend, Accelerator): return - is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == DeviceType.CPU + is_cpu_accelerator_type = self._accelerator_type and self._accelerator_type == _AcceleratorType.CPU _use_cpu = is_cpu_accelerator_type or self.distributed_backend and "cpu" in self.distributed_backend if self.distributed_backend is None: @@ -867,16 +867,16 @@ def set_distributed_mode(self, strategy: Optional[str] = None): self.num_processes = os.cpu_count() # special case with TPUs elif self.has_tpu and not _use_cpu: - self._device_type = DeviceType.TPU + self._device_type = _AcceleratorType.TPU if isinstance(self.tpu_cores, int): self._distrib_type = _StrategyType.TPU_SPAWN elif self.has_ipu and not _use_cpu: - self._device_type = DeviceType.IPU + self._device_type = _AcceleratorType.IPU elif self.distributed_backend and self._distrib_type is None: self._distrib_type = _StrategyType(self.distributed_backend) if self.num_gpus > 0 and not _use_cpu: - self._device_type = DeviceType.GPU + self._device_type = _AcceleratorType.GPU _gpu_distrib_types = (_StrategyType.DP, _StrategyType.DDP, _StrategyType.DDP_SPAWN, _StrategyType.DDP2) # DP and DDP2 cannot run without GPU @@ -896,13 +896,13 @@ def set_distributed_mode(self, strategy: Optional[str] = None): self.check_interactive_compatibility() # for DDP overwrite nb processes by requested GPUs - if self._device_type == DeviceType.GPU and self._distrib_type in ( + if self._device_type == _AcceleratorType.GPU and self._distrib_type in ( _StrategyType.DDP, _StrategyType.DDP_SPAWN, ): self.num_processes = self.num_gpus - if self._device_type == DeviceType.GPU and self._distrib_type == _StrategyType.DDP2: + if self._device_type == _AcceleratorType.GPU 
and self._distrib_type == _StrategyType.DDP2: self.num_processes = self.num_nodes # Horovod is an extra case... @@ -965,8 +965,8 @@ def has_horovodrun() -> bool: def update_device_type_if_ipu_plugin(self) -> None: # This allows the poptorch.Options that are passed into the IPUPlugin to be the source of truth, # which gives users the flexibility to not have to pass `ipus` flag directly to Trainer - if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU: - self._device_type = DeviceType.IPU + if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != _AcceleratorType.IPU: + self._device_type = _AcceleratorType.IPU def update_device_type_if_training_type_plugin_passed(self) -> None: if isinstance(self.strategy, TrainingTypePlugin) or any( @@ -974,18 +974,18 @@ def update_device_type_if_training_type_plugin_passed(self) -> None: ): if self._accelerator_type is not None: if self.use_ipu: - self._device_type = DeviceType.IPU + self._device_type = _AcceleratorType.IPU elif self.use_tpu: - self._device_type = DeviceType.TPU + self._device_type = _AcceleratorType.TPU elif self.use_gpu: - self._device_type = DeviceType.GPU + self._device_type = _AcceleratorType.GPU else: if self.has_ipu: - self._device_type = DeviceType.IPU + self._device_type = _AcceleratorType.IPU elif self.has_tpu: - self._device_type = DeviceType.TPU + self._device_type = _AcceleratorType.TPU elif self.has_gpu: - self._device_type = DeviceType.GPU + self._device_type = _AcceleratorType.GPU def _set_distrib_type_if_training_type_plugin_passed(self): # This is required as when `TrainingTypePlugin` instance is passed to either `strategy` diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index b98f13138b36f..ecd32f11df19e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -21,7 +21,7 @@ from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities import DeviceType, memory +from pytorch_lightning.utilities import _AcceleratorType, memory from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.warnings import rank_zero_deprecation @@ -329,7 +329,7 @@ def gpus_metrics(self) -> Dict[str, float]: .. deprecated:: v1.5 Will be removed in v1.7. 
""" - if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory: + if self.trainer._device_type == _AcceleratorType.GPU and self.log_gpu_memory: mem_map = memory.get_memory_profile(self.log_gpu_memory) self._gpus_metrics.update(mem_map) return self._gpus_metrics diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 667b57fd1c76f..18f13a75bf18e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -63,11 +63,11 @@ from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities import ( + _AcceleratorType, _IPU_AVAILABLE, _StrategyType, _TPU_AVAILABLE, device_parser, - DeviceType, GradClipAlgorithmType, parsing, rank_zero_deprecation, @@ -1519,26 +1519,32 @@ def __setup_profiler(self) -> None: self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir) def _log_device_info(self) -> None: - rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}") + rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == _AcceleratorType.GPU}") - num_tpu_cores = self.tpu_cores if self.tpu_cores is not None and self._device_type == DeviceType.TPU else 0 + num_tpu_cores = ( + self.tpu_cores if self.tpu_cores is not None and self._device_type == _AcceleratorType.TPU else 0 + ) rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores") num_ipus = self.ipus if self.ipus is not None else 0 rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs") - if torch.cuda.is_available() and self._device_type != DeviceType.GPU: + if torch.cuda.is_available() and self._device_type != _AcceleratorType.GPU: rank_zero_warn( "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`." ) - if _TPU_AVAILABLE and self._device_type != DeviceType.TPU: + if _TPU_AVAILABLE and self._device_type != _AcceleratorType.TPU: rank_zero_warn( "TPU available but not used. Set the `tpu_cores` flag in your trainer" " `Trainer(tpu_cores=8)` or script `--tpu_cores=8`." ) - if _IPU_AVAILABLE and self._device_type != DeviceType.IPU and not isinstance(self.accelerator, IPUAccelerator): + if ( + _IPU_AVAILABLE + and self._device_type != _AcceleratorType.IPU + and not isinstance(self.accelerator, IPUAccelerator) + ): rank_zero_warn( "IPU available but not used. Set the `ipus` flag in your trainer" " `Trainer(ipus=8)` or script `--ipus=8`." 
@@ -1595,7 +1601,7 @@ def _distrib_type(self) -> _StrategyType: return self._accelerator_connector._distrib_type @property - def _device_type(self) -> DeviceType: + def _device_type(self) -> _AcceleratorType: return self._accelerator_connector._device_type @property diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 22164908a3e3f..48a18db121d92 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -18,9 +18,9 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only # noqa: F401 from pytorch_lightning.utilities.enums import ( # noqa: F401 + _AcceleratorType, _StrategyType, AMPType, - DeviceType, DistributedType, GradClipAlgorithmType, LightningEnum, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 1d7a6e3fa5452..51eb02c018260 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -143,17 +143,12 @@ def deprecate(self) -> None: ) -class DeviceType(LightningEnum): - """Define Device type by its nature - acceleatrors. +class DeviceType(LightningEnum, metaclass=_OnAccessEnumMeta): + """Define Device type by its nature - accelerators. - >>> DeviceType.CPU == DeviceType.from_str('cpu') - True - >>> # you can match the type with string - >>> DeviceType.GPU == 'GPU' - True - >>> # which is case invariant - >>> DeviceType.TPU in ('tpu', 'CPU') - True + Deprecated since v1.6.0 and will be removed in v1.8.0. + + Use `_AcceleratorType` instead. """ CPU = "CPU" @@ -161,6 +156,12 @@ class DeviceType(LightningEnum): IPU = "IPU" TPU = "TPU" + def deprecate(self) -> None: + rank_zero_deprecation( + "`DeviceType` Enum has been deprecated in v1.6 and will be removed in v1.8." + " Use the string value `{self.value!r}` instead." + ) + class GradClipAlgorithmType(LightningEnum): """Define gradient_clip_algorithm types - training-tricks. @@ -260,6 +261,25 @@ def is_interactive_compatible(self) -> bool: return self in _StrategyType.interactive_compatible_types() +class _AcceleratorType(LightningEnum): + """Define Accelerator type by its nature. 
+ + >>> _AcceleratorType.CPU == _AcceleratorType.from_str('cpu') + True + >>> # you can match the type with string + >>> _AcceleratorType.GPU == 'GPU' + True + >>> # which is case invariant + >>> _AcceleratorType.TPU in ('tpu', 'CPU') + True + """ + + CPU = "CPU" + GPU = "GPU" + IPU = "IPU" + TPU = "TPU" + + class _FaultTolerantMode(LightningEnum): DISABLED = "disabled" diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py index bab6da5368b65..37ff258436568 100644 --- a/pytorch_lightning/utilities/model_summary.py +++ b/pytorch_lightning/utilities/model_summary.py @@ -23,7 +23,7 @@ from torch.utils.hooks import RemovableHandle import pytorch_lightning as pl -from pytorch_lightning.utilities import AMPType, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, AMPType from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.warnings import WarningCache @@ -261,7 +261,11 @@ def _forward_example_input(self) -> None: input_ = model.example_input_array input_ = model._apply_batch_transfer_handler(input_) - if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU: + if ( + trainer is not None + and trainer.amp_backend == AMPType.NATIVE + and trainer._device_type != _AcceleratorType.TPU + ): model.forward = torch.cuda.amp.autocast()(model.forward) mode = model.training diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index a9c9c50d80168..c95c7dc517ef0 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -43,7 +43,7 @@ SLURMEnvironment, TorchElasticEnvironment, ) -from pytorch_lightning.utilities import _StrategyType, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -729,7 +729,7 @@ def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin): trainer = Trainer(strategy=plugin(), gpus=2) assert isinstance(trainer.training_type_plugin, plugin) - assert trainer._device_type == DeviceType.GPU + assert trainer._device_type == _AcceleratorType.GPU assert isinstance(trainer.accelerator, GPUAccelerator) diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index be2e597c9a2f9..524e122478bad 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -24,7 +24,7 @@ from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.trainer.supporters import CombinedLoader -from pytorch_lightning.utilities import _IPU_AVAILABLE, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, _IPU_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule @@ -571,7 +571,7 @@ def test_device_type_when_training_plugin_ipu_passed(tmpdir): trainer = Trainer(strategy=IPUPlugin(), ipus=8) assert isinstance(trainer.training_type_plugin, IPUPlugin) - assert trainer._device_type == DeviceType.IPU + assert trainer._device_type == _AcceleratorType.IPU assert isinstance(trainer.accelerator, IPUAccelerator) diff --git 
a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index f668f63b9f450..0c32773b56e1c 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -14,10 +14,16 @@ """Test deprecated functionality which will be removed in v1.8.0.""" import pytest -from pytorch_lightning.utilities.enums import DistributedType +from pytorch_lightning.utilities.enums import DeviceType, DistributedType def test_v1_8_0_deprecated_distributed_type_enum(): with pytest.deprecated_call(match="has been deprecated in v1.6 and will be removed in v1.8."): _ = DistributedType.DDP + + +def test_v1_8_0_deprecated_device_type_enum(): + + with pytest.deprecated_call(match="has been deprecated in v1.6 and will be removed in v1.8."): + _ = DeviceType.CPU diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index bb4c1d017d3ec..ea8d430918e3f 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -26,7 +26,7 @@ from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync -from pytorch_lightning.utilities import _TPU_AVAILABLE, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset @@ -474,5 +474,5 @@ def test_device_type_when_training_plugin_tpu_passed(tmpdir): trainer = Trainer(strategy=TPUSpawnPlugin(), tpu_cores=8) assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin) - assert trainer._device_type == DeviceType.TPU + assert trainer._device_type == _AcceleratorType.TPU assert isinstance(trainer.accelerator, TPUAccelerator) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2d39d83ec38ab..6004d4540a85f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -48,7 +48,7 @@ DDPSpawnShardedPlugin, ) from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _StrategyType, DeviceType +from pytorch_lightning.utilities import _AcceleratorType, _StrategyType from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything @@ -1149,75 +1149,75 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): [ ( dict(accelerator=None, gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator="dp", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator="ddp", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator="ddp", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( 
dict(accelerator="ddp", num_nodes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator="ddp_cpu", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="ddp2", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(accelerator=None, gpus=1), - dict(_distrib_type=None, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator="dp", gpus=1), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator="ddp", gpus=1), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator="ddp_cpu", num_processes=2, gpus=1), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="ddp2", gpus=1), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(accelerator=None, gpus=2), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=2), ), ( dict(accelerator="dp", gpus=2), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(accelerator="ddp", gpus=2), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=2), ), ( dict(accelerator="ddp2", gpus=2), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(accelerator="ddp2", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(accelerator="dp", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ], ) @@ -2091,118 +2091,118 @@ def training_step(self, batch, batch_idx): [ ( 
dict(strategy=None, gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="dp", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="ddp", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="ddp", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="ddp", num_nodes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="ddp2", gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy=None, gpus=1), - dict(_distrib_type=None, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="dp", gpus=1), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="ddp", gpus=1), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="ddp_spawn", gpus=1), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy="ddp2", gpus=1), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1), ), ( dict(strategy=None, gpus=2), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=2), ), ( dict(strategy="dp", gpus=2), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy="ddp", gpus=2), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=2), ), ( dict(strategy="ddp2", gpus=2), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy="ddp2", 
num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="dp", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="ddp_spawn", num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy="ddp_spawn", num_processes=1, gpus=None), - dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + dict(_distrib_type=None, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=1), ), ( dict(strategy="ddp_fully_sharded", gpus=1), dict( _distrib_type=_StrategyType.DDP_FULLY_SHARDED, - _device_type=DeviceType.GPU, + _device_type=_AcceleratorType.GPU, num_gpus=1, num_processes=1, ), ), ( dict(strategy=DDPSpawnPlugin(), num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy=DDPSpawnPlugin(), gpus=2), - dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP_SPAWN, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DDPPlugin(), num_processes=2, gpus=None), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.CPU, num_gpus=0, num_processes=2), ), ( dict(strategy=DDPPlugin(), gpus=2), - dict(_distrib_type=_StrategyType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DDP2Plugin(), gpus=2), - dict(_distrib_type=_StrategyType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DDP2, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DataParallelPlugin(), gpus=2), - dict(_distrib_type=_StrategyType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict(_distrib_type=_StrategyType.DP, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1), ), ( dict(strategy=DDPFullyShardedPlugin(), gpus=2), dict( _distrib_type=_StrategyType.DDP_FULLY_SHARDED, - _device_type=DeviceType.GPU, + _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1, ), @@ -2211,14 +2211,16 @@ def training_step(self, batch, batch_idx): dict(strategy=DDPSpawnShardedPlugin(), gpus=2), dict( _distrib_type=_StrategyType.DDP_SHARDED_SPAWN, - _device_type=DeviceType.GPU, + _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1, ), ), ( dict(strategy=DDPShardedPlugin(), gpus=2), - dict(_distrib_type=_StrategyType.DDP_SHARDED, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + dict( + _distrib_type=_StrategyType.DDP_SHARDED, _device_type=_AcceleratorType.GPU, num_gpus=2, num_processes=1 + ), ), ], ) diff --git 
a/tests/utilities/test_enums.py b/tests/utilities/test_enums.py index 4f902e2238d1c..99158e2e83c79 100644 --- a/tests/utilities/test_enums.py +++ b/tests/utilities/test_enums.py @@ -13,17 +13,17 @@ # limitations under the License. import pytest -from pytorch_lightning.utilities.enums import DeviceType, GradClipAlgorithmType, ModelSummaryMode, PrecisionType +from pytorch_lightning.utilities.enums import _AcceleratorType, GradClipAlgorithmType, ModelSummaryMode, PrecisionType def test_consistency(): - assert DeviceType.TPU not in ("GPU", "CPU") - assert DeviceType.TPU in ("TPU", "CPU") - assert DeviceType.TPU in ("tpu", "CPU") - assert DeviceType.TPU not in {"GPU", "CPU"} + assert _AcceleratorType.TPU not in ("GPU", "CPU") + assert _AcceleratorType.TPU in ("TPU", "CPU") + assert _AcceleratorType.TPU in ("tpu", "CPU") + assert _AcceleratorType.TPU not in {"GPU", "CPU"} # hash cannot be case invariant - assert DeviceType.TPU not in {"TPU", "CPU"} - assert DeviceType.TPU in {"tpu", "CPU"} + assert _AcceleratorType.TPU not in {"TPU", "CPU"} + assert _AcceleratorType.TPU in {"tpu", "CPU"} def test_precision_supported_types(): From 3d6262b7a91215e72019d720e742e6261e1636dc Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 25 Nov 2021 17:31:53 +0000 Subject: [PATCH 40/59] Fault Tolerant Manual: Add support for DDP (#10638) --- CHANGELOG.md | 3 +++ .../loops/epoch/evaluation_epoch_loop.py | 15 +++++++++++--- .../loops/epoch/training_epoch_loop.py | 6 ++++-- .../trainer/connectors/data_connector.py | 2 ++ pytorch_lightning/trainer/supporters.py | 5 +++++ pytorch_lightning/trainer/trainer.py | 15 ++++++++++---- pytorch_lightning/utilities/auto_restart.py | 13 ++++++++++++ pytorch_lightning/utilities/distributed.py | 20 +++++++++---------- pytorch_lightning/utilities/imports.py | 5 +++-- tests/utilities/test_auto_restart.py | 8 ++++++++ tests/utilities/test_distributed.py | 9 ++++----- 11 files changed, 74 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 43c0c6ab144f2..3a136fe023084 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added + - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601)) @@ -21,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699)) * Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703)) * Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707)) + * Broadcast the `_terminate_gracefully` to all processes and add support for DDP ([#10638](https://github.com/PyTorchLightning/pytorch-lightning/issues/10638)) + - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 2fc572ea252e6..b7bfc1e0ed8a2 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -22,8 +22,13 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.utilities import _update_dataloader_iter from pytorch_lightning.trainer.progress import BatchProgress -from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState +from pytorch_lightning.utilities.auto_restart import ( + _collect_states_on_rank_zero_over_collection, + _reload_dataloader_state_dict, + MergedIteratorState, +) from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher +from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT @@ -173,12 +178,16 @@ def on_save_checkpoint(self) -> Dict: state_to_save = "state" if self._has_completed() else "previous_state" state: Optional[MergedIteratorState] = getattr(self._data_fetcher.dataloader_iter, state_to_save, None) if state: - state_dict["dataloader_state_dict"] = asdict(state) + state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(asdict(state)) return state_dict def on_load_checkpoint(self, state_dict: Dict) -> None: # cache the dataloader state dict until the dataloader objects are available - self._dataloader_state_dict = state_dict.get("dataloader_state_dict") + # dataset states are collected across all ranks + dataloader_state_dict = state_dict.get("dataloader_state_dict", None) + if not _fault_tolerant_training() or not dataloader_state_dict: + return + self._dataloader_state_dict = dataloader_state_dict[self.trainer.global_rank] def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher): if not self.trainer.sanity_checking and self._dataloader_state_dict: diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 8ddca3ad505e8..a75ad470c29ef 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -25,6 +25,7 @@ from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection from 
pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import AbstractDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden @@ -320,8 +321,9 @@ def on_save_checkpoint(self) -> Dict: or self.batch_progress.current.ready == 0 # did not start ): return state_dict - state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict( - has_completed=self._has_completed() + + state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection( + self.trainer.train_dataloader.state_dict(has_completed=self._has_completed()) ) return state_dict diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index deee64c90fe2e..e6f76e0403bd7 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -19,6 +19,7 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.auto_restart import _teardown_dataloader_get_iterators from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import ( AbstractDataFetcher, @@ -254,6 +255,7 @@ def teardown(self) -> None: if self.sanity_check_data_fetcher: self.sanity_check_data_fetcher.teardown() self.sanity_check_data_fetcher = None + _teardown_dataloader_get_iterators() @dataclass diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index d65bc08e6689e..df86ea157f3ac 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -29,6 +29,7 @@ patch_dataloader_iterator, ) from pytorch_lightning.utilities.data import get_len +from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -403,6 +404,10 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator: if isinstance(dataloader, CycleIterator): dataloader = dataloader_to_iter_on.loader + # dataset states are collected across all ranks + rank = torch.distributed.get_rank() if distributed_available() else 0 + state_dict = state_dict[rank] + _reload_dataloader_state_dict(dataloader, state_dict) # We finally spawned the workers if any. 
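The hunks above change the fault-tolerant checkpoint layout so that rank 0 gathers every process's dataloader/dataset state into a single mapping keyed by global rank, and each process indexes back into that mapping with its own rank when the state is reloaded. As a rough illustration of that layout only — the dictionary contents below are invented and `restore_for_this_rank` is a hypothetical helper, not Lightning's API — a per-rank restore looks roughly like this:

import torch.distributed as dist

# Assumed shape: {global_rank: that rank's dataset/dataloader state}, mirroring what
# _collect_states_on_rank_zero_over_collection produces at checkpoint time.
collected = {0: {"counter": 4}, 1: {"counter": 5}}  # invented example values

def restore_for_this_rank(collected_states):
    # Each process reads back only its own entry; fall back to rank 0 when
    # torch.distributed is not initialized (e.g. a single-process run).
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    return collected_states[rank]

print(restore_for_this_rank(collected))  # prints {'counter': 4} in a single-process run

Keying by rank matters because each worker holds its own sampler and RNG state; restoring rank 0's state on every process would presumably break the reproducibility that fault-tolerant training is meant to guarantee.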
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 18f13a75bf18e..73e9437040ae2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2100,10 +2100,17 @@ def _results(self) -> Optional[ResultCollection]: return active_loop._results def _exit_gracefully_on_signal(self) -> None: - if _fault_tolerant_training() and self._terminate_gracefully: - caller = inspect.stack()[1] - class_name = caller[0].f_locals["self"].__class__.__name__ - raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}") + if not _fault_tolerant_training(): + return + if not self._should_terminated_gracefully(): + return + caller = inspect.stack()[1] + class_name = caller[0].f_locals["self"].__class__.__name__ + raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}") + + def _should_terminated_gracefully(self) -> bool: + value = torch.tensor(self._terminate_gracefully, device=self.training_type_plugin.root_device) + return self.training_type_plugin.reduce(value, reduce_op="sum") > 0 @property def weights_summary(self) -> Optional[str]: diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 9f99634bd15b2..84f0c9decefea 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -31,6 +31,8 @@ from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -737,3 +739,14 @@ def _teardown_dataloader_get_iterators() -> None: if get_iterator: DataLoader._get_iterator = get_iterator del DataLoader._ori_get_iterator + + +def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any: + """This utility collects the state across processes for a collection of state.""" + + def fn(state: Dict): + if key in state: + return _collect_states_on_rank_zero(state) + return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()} + + return apply_to_collection(state_dict, Dict, fn) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 7c6e4f4048181..5612752569f7a 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -378,7 +378,14 @@ def init_dist_connection( ) -def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Any]]: +def _broadcast_object_list(obj: Any, rank: int) -> Any: + objects = [obj if torch.distributed.get_rank() == rank else None] + torch.distributed.broadcast_object_list(objects, src=rank) + return objects[0] + + +# TODO: Refactor with the Strategy Collectives once finalized. +def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]: """This distributed utility collects dictionary state across all processes. 
Args: @@ -391,13 +398,4 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> """ if not distributed_available(): return {0: state} - states = {} - current_rank = torch.distributed.get_rank() - for rank in range(1, torch.distributed.get_world_size()): - objects = [state if current_rank == rank else None] - torch.distributed.broadcast_object_list(objects, src=rank, device=device) - states[rank] = objects[0] - if current_rank != 0: - return None - states[0] = state - return states + return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())} diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index aa6349b5d677a..49c94d87e64c2 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -14,7 +14,6 @@ """General utilities.""" import importlib import operator -import os import platform import sys from importlib.util import find_spec @@ -111,4 +110,6 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: # experimental feature within PyTorch Lightning. def _fault_tolerant_training() -> bool: - return bool(int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0))) + from pytorch_lightning.utilities.enums import _FaultTolerantMode + + return _FaultTolerantMode.detect_current_mode().is_enabled diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index c69b70b65b13c..58a11d0de69db 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,6 +39,7 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, + _collect_states_on_rank_zero_over_collection, _MultiProcessingDataLoaderIterStateful, _patch_dataloader_get_iterators, _reload_dataloader_state_dict, @@ -1254,6 +1255,13 @@ def load_state_dict(self, state_dict): self.counter = state_dict[0]["counter"] +def test_collect_states_with_collection(): + state = {"state": 0} + collection = [{"a": state, "b": [{"a": state}]}] + generated = _collect_states_on_rank_zero_over_collection(collection) + assert generated == [{"a": {0: state}, "b": [{"a": {0: state}}]}] + + @pytest.mark.parametrize("num_workers", [0]) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) def test_stateful_workers(num_workers): diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py index a48b4486a470f..6226aadecbf9e 100644 --- a/tests/utilities/test_distributed.py +++ b/tests/utilities/test_distributed.py @@ -64,15 +64,14 @@ def foo(): def _test_collect_states(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" + torch.cuda.set_device(f"cuda:{rank}") + # initialize the process group torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) state = {"something": torch.tensor([rank])} - collected_state = _collect_states_on_rank_zero(state, device=torch.device(f"cuda:{rank}")) - if rank == 0: - assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}} - else: - assert collected_state is None + collected_state = _collect_states_on_rank_zero(state) + assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}} @RunIf(skip_windows=True, min_gpus=2, min_torch="1.10") From e507bc902703cc6e966e11d2123e45456169b7dc Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> 
Date: Fri, 26 Nov 2021 14:45:22 +0530 Subject: [PATCH 41/59] Fix compare version for packages (#10762) --- CHANGELOG.md | 2 ++ pytorch_lightning/utilities/imports.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a136fe023084..43e36c8adb681 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -199,6 +199,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the `{validation,test}_step` outputs getting moved to CPU with `Trainer(move_metrics_to_cpu=True)` ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631)) +- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762)) + ## [1.5.2] - 2021-11-16 diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 49c94d87e64c2..daffe817f87dd 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -58,7 +58,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: pkg_version = Version(pkg.__version__) else: # try pkg_resources to infer version - pkg_version = Version(pkg_resources.get_distribution(pkg).version) + pkg_version = Version(pkg_resources.get_distribution(package).version) except TypeError: # this is mocked by Sphinx, so it should return True to generate all summaries return True From 412d507a73c79f5e4f7ef14289cefe2eb2af94a7 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 26 Nov 2021 13:37:27 +0000 Subject: [PATCH 42/59] Fault Tolerant: move signal to SIGTERM (#10605) --- .github/workflows/ci_test-conda.yml | 2 +- CHANGELOG.md | 3 + pl_examples/fault_tolerant/automatic.py | 137 ++++++++++++++++++ pl_examples/run_examples.sh | 2 + .../trainer/connectors/signal_connector.py | 5 +- pytorch_lightning/trainer/trainer.py | 14 +- pytorch_lightning/utilities/exceptions.py | 4 +- .../connectors/test_signal_connector.py | 7 +- tests/utilities/test_all_gather_grad.py | 13 ++ tests/utilities/test_auto_restart.py | 18 ++- 10 files changed, 184 insertions(+), 21 deletions(-) create mode 100644 pl_examples/fault_tolerant/automatic.py diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index a24a15baf4d78..bc75921dc61fc 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -43,7 +43,7 @@ jobs: - name: Tests run: | # NOTE: run coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003 - coverage run --source pytorch_lightning -m pytest --random-order-seed=1 pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml + coverage run --source pytorch_lightning -m pytest --random-order-seed=2 pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml shell: bash -l {0} - name: Upload pytest results diff --git a/CHANGELOG.md b/CHANGELOG.md index 43e36c8adb681..c351b93d43b9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) +- Fault Tolerant relies on `signal.SIGTERM` to gracefully exit instead of `signal.SIGUSR1` ([#10605](https://github.com/PyTorchLightning/pytorch-lightning/pull/10605)) + + - Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541)) diff --git a/pl_examples/fault_tolerant/automatic.py b/pl_examples/fault_tolerant/automatic.py new file mode 100644 index 0000000000000..b3e7dc6db3a49 --- /dev/null +++ b/pl_examples/fault_tolerant/automatic.py @@ -0,0 +1,137 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Here is an example of `Lightning Fault Tolerant Automatic`. + +Find the documentation: https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html + +RUN WITHOUT FAILURE: + + 1. Launch `python pl_examples/fault_tolerant/automatic.py`. + - You should see `[-1.1343, 0.0186]` in the logs. + +RUN WITH SIMULATED FAILURE: + + 1. Launch `python pl_examples/fault_tolerant/automatic.py --emulate_kill_signal`. + - You should see `kill -SIGTERM {PID}` in the logs. + 2. Run this command within another terminal. + - You should see `Received signal 15. Saving a fault-tolerant checkpoint and terminating.` in the logs. + 3. Launch `python pl_examples/fault_tolerant/automatic.py --emulate_kill_signal` again. + - You should see `Restored all states from the checkpoint file at ./.pl_auto_save.ckpt` + - And you should see `[-1.1343, 0.0186]` in the logs. + + To restart the process, just run `rm .pl_auto_save.ckpt` to delete the auto restart checkpoint. + +This example shows that the weights trained with a failure match the weights trained without failure, +thus the training has been properly resumed whilst being fully reproducible. + +Used PyTorch 1.7.1.
+""" + +import os +import random as python_random +from argparse import ArgumentParser +from time import sleep + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_lightning import _logger as log +from pytorch_lightning import LightningModule, seed_everything, Trainer + + +class RandomGetItemDataset(Dataset): + """A dataset with random elements generated using global rng from torch, numpy and python.""" + + def __init__(self, length, size): + self.size = size + self.len = length + + def __getitem__(self, index): + t = torch.rand(self.size) + n = torch.from_numpy(np.random.rand(self.size)) + p = torch.tensor([python_random.random() for _ in range(self.size)]) + sample = (index + (t + n + p) / 10).float() + return sample + + def __len__(self): + return self.len + + +class SimpleMLP(LightningModule): + def __init__(self, fail_on_step: int = -1): + super().__init__() + self.layer = torch.nn.Linear(1, 2) + self.seen_batches = [] + self.fail_on_step = fail_on_step + + def training_step(self, batch, batch_idx): + if self.global_step == self.fail_on_step: + log.info( + f"READY TO BE KILLED WITH SIGTERM SIGNAL. " f"Run `kill -SIGTERM {os.getpid()}` in another terminal." + ) + # this line is used to wait for you to send the signal to exit gracefully. + while not self.trainer._terminate_gracefully: + sleep(0.1) + batch = batch["data"] if isinstance(batch, dict) else batch + self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch) + loss = sum(self.layer(b).sum() for b in batch) + return loss + + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + def train_dataloader(self): + return DataLoader(RandomGetItemDataset(3, 1)) + + +def _run_training(default_root_dir=".", max_epochs=3, fail_on_step: int = -1, ckpt_path=None): + model = SimpleMLP(fail_on_step=fail_on_step) + trainer = Trainer(default_root_dir=default_root_dir, max_epochs=max_epochs) + trainer.fit(model, ckpt_path=ckpt_path) + return model.seen_batches, model.parameters() + + +def main(args): + seed_everything(42) + os.environ["PL_FAULT_TOLERANT_TRAINING"] = "automatic" # active fault tolerant automatic + + ckpt_path = ".pl_auto_save.ckpt" + auto_restart_ckpt_path_exists = os.path.exists(ckpt_path) + if args.emulate_kill_signal: + fail_on_step = -1 if auto_restart_ckpt_path_exists else 4 + completed_batches = 4 if auto_restart_ckpt_path_exists else 5 + else: + fail_on_step = -1 + completed_batches = 9 + + complete_batches, weights = _run_training(fail_on_step=fail_on_step) + assert len(complete_batches) == completed_batches + + if not auto_restart_ckpt_path_exists and args.emulate_kill_signal: + assert os.path.exists(ckpt_path) + + if auto_restart_ckpt_path_exists or not args.emulate_kill_signal: + log.info([w for w in weights]) + + +if __name__ == "__main__": + parser = ArgumentParser(description="Fault Tolerant Under Signal Example") + parser.add_argument( + "--emulate_kill_signal", + action="store_true", + help="Whether you should gracefully kill the process with a `SIGTERM` signal.", + ) + main(parser.parse_args()) diff --git a/pl_examples/run_examples.sh b/pl_examples/run_examples.sh index a04a57631d9cb..792894137a463 100755 --- a/pl_examples/run_examples.sh +++ b/pl_examples/run_examples.sh @@ -35,3 +35,5 @@ args=" python "${dir_path}/basic_examples/mnist_examples/image_classifier_4_lightning_module.py" ${args} "$@" python "${dir_path}/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py" ${args} "$@" 
+ +python "${dir_path}/fault_tolerant/automatic.py" diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 90d0f6928283f..189f81eb5c8cc 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -36,7 +36,7 @@ def register_signal_handlers(self) -> None: sigterm_handlers: List[Callable] = [] if _fault_tolerant_training(): - sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn) + sigterm_handlers.append(self.fault_tolerant_sigterm_handler_fn) environment = self.trainer._accelerator_connector.cluster_environment if isinstance(environment, SLURMEnvironment) and environment.auto_requeue: @@ -83,7 +83,8 @@ def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: if self.trainer.logger: self.trainer.logger.finalize("finished") - def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: + def fault_tolerant_sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None: + log.info(f"Received signal {signum}. Saving a fault-tolerant checkpoint and terminating.") self.trainer._terminate_gracefully = True def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 73e9437040ae2..9bdd658968b77 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1550,7 +1550,7 @@ def _log_device_info(self) -> None: " `Trainer(ipus=8)` or script `--ipus=8`." ) - def _on_exception(self): + def _on_exception(self) -> None: if not _fault_tolerant_training(): return # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. @@ -2100,16 +2100,12 @@ def _results(self) -> Optional[ResultCollection]: return active_loop._results def _exit_gracefully_on_signal(self) -> None: - if not _fault_tolerant_training(): - return - if not self._should_terminated_gracefully(): + if not _fault_tolerant_training() or not self._should_terminate_gracefully(): return - caller = inspect.stack()[1] - class_name = caller[0].f_locals["self"].__class__.__name__ - raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}") + raise ExitGracefullyException(0) - def _should_terminated_gracefully(self) -> bool: - value = torch.tensor(self._terminate_gracefully, device=self.training_type_plugin.root_device) + def _should_terminate_gracefully(self) -> bool: + value = torch.tensor(int(self._terminate_gracefully), device=self.training_type_plugin.root_device) return self.training_type_plugin.reduce(value, reduce_op="sum") > 0 @property diff --git a/pytorch_lightning/utilities/exceptions.py b/pytorch_lightning/utilities/exceptions.py index 9afa2968ed831..94e2c219d9895 100644 --- a/pytorch_lightning/utilities/exceptions.py +++ b/pytorch_lightning/utilities/exceptions.py @@ -21,8 +21,8 @@ class DeadlockDetectedException(Exception): """Exception used when a deadlock has been detected and processes are being killed.""" -class ExitGracefullyException(Exception): - """Exception used when a ``signal.SIGUSR1`` is sent to the process. +class ExitGracefullyException(SystemExit): + """Exception used when a ``signal.SIGTERM`` is sent to the process. This signals Lightning to try to create a fault-tolerance checkpoint once the current batch or epoch is reached (assuming it can be done under 30 sec). 
After the checkpoint is saved, Lightning will exit. diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index fbfce158e3675..636fddd27d738 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -37,19 +37,18 @@ def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpd def handler(*_): pass - signal.signal(signal.SIGUSR1, handler) + signal.signal(signal.SIGTERM, handler) class TestModel(BoringModel): def training_step(self, batch, batch_idx): if terminate_gracefully or register_handler: - os.kill(os.getpid(), signal.SIGUSR1) + os.kill(os.getpid(), signal.SIGTERM) sleep(0.1) return super().training_step(batch, batch_idx) model = TestModel() with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}): - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0) if terminate_gracefully and not register_handler: with pytest.raises(ExitGracefullyException): @@ -59,7 +58,7 @@ def training_step(self, batch, batch_idx): assert trainer._terminate_gracefully == (False if register_handler else terminate_gracefully) # reset the signal to system defaults - signal.signal(signal.SIGUSR1, signal.SIG_DFL) + signal.signal(signal.SIGTERM, signal.SIG_DFL) @RunIf(skip_windows=True) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 2ed42b0b0f21a..0ecafa347e574 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import sys diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 58a11d0de69db..1a479af05aa3f 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect import math import os import random @@ -53,7 +54,7 @@ MergedIteratorState, ) from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys -from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training from tests.helpers.boring_model import BoringModel, RandomDataset @@ -1043,7 +1044,7 @@ def __init__(self, should_signal: bool, failure_on_step: bool, failure_on_traini def _signal(self): if self.should_signal: - # simulate `os.kill(os.getpid(), signal.SIGUSR1)` + # simulate `os.kill(os.getpid(), signal.SIGTERM)` self.trainer._terminate_gracefully = True def training_step(self, batch, batch_idx): @@ -1093,7 +1094,18 @@ def _fit_model( num_sanity_val_steps=0, ) - trainer = Trainer(**trainer_kwargs) + class ExitGracefullyException(Exception): + pass + + class TestTrainer(Trainer): + def _exit_gracefully_on_signal(self) -> None: + if not _fault_tolerant_training() or not self._should_terminate_gracefully(): + return + caller = inspect.stack()[1] + class_name = caller[0].f_locals["self"].__class__.__name__ + raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}") + + trainer = TestTrainer(**trainer_kwargs) if should_signal: with pytest.raises(ExitGracefullyException, match=status): trainer.fit(model) From 6fe6e9e4143424810846674f598cd91127bd6bdd Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 26 Nov 2021 17:07:57 +0000 Subject: [PATCH 43/59] Delete TensorBoardLogger experiment before spawning the processes. (#10777) --- CHANGELOG.md | 3 +++ .../plugins/training_type/ddp_spawn.py | 17 +++++++++++++++++ .../plugins/training_type/tpu_spawn.py | 9 --------- tests/loggers/test_tensorboard.py | 15 +++++++++++++++ 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c351b93d43b9d..f52369b443164 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762)) +- Fixed the TensorBoardLogger `SummaryWriter` not being closed before spawning the processes ([#10777](https://github.com/PyTorchLightning/pytorch-lightning/pull/10777)) + + ## [1.5.2] - 2021-11-16 ### Fixed diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index da724944ade7e..e09100b77207d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -25,6 +25,7 @@ from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment @@ -147,14 +148,17 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st return {"nprocs": self.num_processes} def start_training(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) self.spawn(self.new_process, trainer, self.mp_queue, return_result=False) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] def start_evaluating(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) self.spawn(self.new_process, trainer, self.mp_queue, return_result=False) def start_predicting(self, trainer: "pl.Trainer") -> None: + self._clean_logger(trainer) self.spawn(self.new_process, trainer, self.mp_queue, return_result=False) def spawn(self, function: Callable, *args: Any, return_result: bool = True, **kwargs: Any) -> Optional[Any]: @@ -415,3 +419,16 @@ def teardown(self) -> None: self.lightning_module.cpu() # clean up memory torch.cuda.empty_cache() + + @staticmethod + def _clean_logger(trainer: "pl.Trainer") -> None: + loggers = trainer.logger._logger_iterable if isinstance(trainer.logger, LoggerCollection) else [trainer.logger] + for logger in loggers: + if isinstance(logger, TensorBoardLogger) and logger._experiment is not None: + rank_zero_warn( + "When using `ddp_spawn`, the `TensorBoardLogger` experiment should be `None`. Setting it to `None`." + ) + # the experiment class of `TensorBoard` holds a multiprocessing queue which can make ours hang. + # we want to make sure these are closed before we spawn our own threads. + # assuming nothing else references the experiment object, python should instantly `__del__` it.
+ logger._experiment = None diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 3ab9a8171aac5..9c8f7f18230b8 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -258,10 +258,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ return output - def _close_logger(self, trainer) -> None: - if trainer.logger is not None: - trainer.logger.finalize("success") - def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: return { "nprocs": len(self.parallel_devices), @@ -297,13 +293,8 @@ def start_training(self, trainer: "pl.Trainer") -> None: # todo: precision pluging is call in accelerator setup and should be moved if "XLA_USE_BF16" in os.environ: del os.environ["XLA_USE_BF16"] - self._close_logger(trainer) return super().start_training(trainer) - def start_evaluating(self, trainer: "pl.Trainer") -> None: - self._close_logger(trainer) - return super().start_evaluating(trainer) - def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 02a809aa2ab30..0a99c058ef941 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -25,6 +25,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers.base import LoggerCollection from pytorch_lightning.utilities.imports import _compare_version from tests.helpers import BoringModel @@ -332,3 +333,17 @@ def test_tensorboard_missing_folder_warning(tmpdir, caplog): assert logger.version == 0 assert "Missing logger folder:" in caplog.text + + +@pytest.mark.parametrize("use_list", [False, True]) +def test_tensorboard_ddp_spawn_cleanup(use_list, tmpdir): + tensorboard_logger = TensorBoardLogger(save_dir=tmpdir) + assert tensorboard_logger._experiment is None + tensorboard_logger.experiment # this property access will create the experiment + assert tensorboard_logger._experiment is not None + logger = [tensorboard_logger] if use_list else tensorboard_logger + trainer = Trainer(strategy="ddp_spawn", devices=2, accelerator="auto", logger=logger) + trainer.training_type_plugin._clean_logger(trainer) + if use_list: + assert isinstance(trainer.logger, LoggerCollection) + assert tensorboard_logger._experiment is None From 152eb57defc37ca09478a344a25777dd164a7452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 26 Nov 2021 18:13:14 +0100 Subject: [PATCH 44/59] Rename special to standalone (#10779) --- .azure-pipelines/gpu-tests.yml | 4 +- .../test_accelerator_connector.py | 2 +- tests/accelerators/test_ddp.py | 2 +- tests/accelerators/test_multi_nodes_gpu.py | 4 +- tests/callbacks/test_pruning.py | 2 +- tests/callbacks/test_stochastic_weight_avg.py | 2 +- tests/callbacks/test_tqdm_progress_bar.py | 2 +- .../test_checkpoint_callback_frequency.py | 2 +- tests/conftest.py | 8 +-- tests/core/test_metric_result_integration.py | 2 +- tests/helpers/runif.py | 12 ++-- tests/lite/test_lite.py | 2 +- tests/lite/test_parity.py | 2 +- tests/models/test_hooks.py | 4 +- tests/models/test_sync_batchnorm.py | 2 +- .../environments/torch_elastic_deadlock.py | 2 +- tests/plugins/test_amp_plugins.py | 2 +- ..._ddp_fully_sharded_with_full_state_dict.py | 6 +- tests/plugins/test_ddp_plugin.py | 4 +- .../plugins/test_ddp_plugin_with_comm_hook.py | 10 
++-- tests/plugins/test_deepspeed_plugin.py | 58 +++++++++---------- tests/plugins/test_sharded_plugin.py | 6 +- tests/profiler/test_profiler.py | 6 +- .../{special_tests.sh => standalone_tests.sh} | 14 ++--- .../logging_/test_train_loop_logging.py | 2 +- .../optimization/test_manual_optimization.py | 4 +- tests/trainer/optimization/test_optimizers.py | 2 +- tests/trainer/test_trainer.py | 6 +- tests/utilities/test_all_gather_grad.py | 4 +- .../test_deepspeed_collate_checkpoint.py | 2 +- tests/utilities/test_meta.py | 2 +- tests/utilities/test_warnings.py | 4 +- 32 files changed, 93 insertions(+), 93 deletions(-) rename tests/{special_tests.sh => standalone_tests.sh} (82%) diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 71332a840fdb0..8752e8584439a 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -72,10 +72,10 @@ jobs: displayName: 'Testing: standard' - bash: | - bash tests/special_tests.sh + bash tests/standalone_tests.sh env: PL_USE_MOCKED_MNIST: "1" - displayName: 'Testing: special' + displayName: 'Testing: standalone' - bash: | python -m coverage report diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index c95c7dc517ef0..51316c155368c 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -337,7 +337,7 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) -@RunIf(skip_windows=True, special=True) +@RunIf(skip_windows=True, standalone=True) def test_accelerator_choice_ddp_cpu_and_strategy(tmpdir): """Test that accelerator="ddp_cpu" can work together with an instance of DDPPlugin.""" _test_accelerator_choice_ddp_cpu_and_strategy(tmpdir, ddp_strategy_class=DDPPlugin) diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py index 1982e967c21ea..db2f388971c12 100644 --- a/tests/accelerators/test_ddp.py +++ b/tests/accelerators/test_ddp.py @@ -108,7 +108,7 @@ def setup(self, stage: Optional[str] = None) -> None: trainer.fit(model) -@RunIf(min_gpus=2, min_torch="1.8.1", special=True) +@RunIf(min_gpus=2, min_torch="1.8.1", standalone=True) @pytest.mark.parametrize("precision", (16, 32)) def test_ddp_wrapper(tmpdir, precision): """Test parameters to ignore are carried over for DDP.""" diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py index 0df49a41b0fd0..09f632746b1dd 100644 --- a/tests/accelerators/test_multi_nodes_gpu.py +++ b/tests/accelerators/test_multi_nodes_gpu.py @@ -31,7 +31,7 @@ # TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml) # use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)` @pytest.mark.skip("Multi-node testing is currently disabled") -@RunIf(special=True) +@RunIf(standalone=True) def test_logging_sync_dist_true_ddp(tmpdir): """Tests to ensure that the sync_dist flag works with CPU (should just return the original value)""" fake_result = 1 @@ -68,7 +68,7 @@ def validation_step(self, batch, batch_idx): # TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml) # use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)` @pytest.mark.skip("Multi-node testing is currently disabled") -@RunIf(special=True) +@RunIf(standalone=True) def test__validation_step__log(tmpdir): """Tests that validation_step can log.""" diff --git a/tests/callbacks/test_pruning.py 
b/tests/callbacks/test_pruning.py index ec4dcddf777c0..f63892df94310 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -160,7 +160,7 @@ def test_pruning_callback( ) -@RunIf(special=True, min_gpus=2) +@RunIf(standalone=True, min_gpus=2) @pytest.mark.parametrize("parameters_to_prune", (False, True)) @pytest.mark.parametrize("use_global_unstructured", (False, True)) def test_pruning_callback_ddp(tmpdir, parameters_to_prune, use_global_unstructured): diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index d30edb177ed10..584e24bb71ed9 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -138,7 +138,7 @@ def train_with_swa( assert trainer.lightning_module == model -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) def test_swa_callback_ddp(tmpdir): train_with_swa(tmpdir, strategy="ddp", gpus=2) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 1ff1a602fe3b6..ba66ad169f473 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -512,7 +512,7 @@ def test_tqdm_progress_bar_can_be_pickled(): pickle.dumps(bar) -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) @pytest.mark.parametrize( ["total_train_samples", "train_batch_size", "total_val_samples", "val_batch_size", "val_check_interval"], [(8, 4, 2, 1, 0.2), (8, 4, 2, 1, 0.5)], diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index fd5c76b2faef7..2c14c7de29b9c 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -87,7 +87,7 @@ def training_step(self, batch, batch_idx): @mock.patch("torch.save") -@RunIf(special=True, min_gpus=2) +@RunIf(standalone=True, min_gpus=2) @pytest.mark.parametrize(["k", "epochs", "val_check_interval", "expected"], [(1, 1, 1.0, 1), (2, 2, 0.3, 4)]) def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected): class TestModel(BoringModel): diff --git a/tests/conftest.py b/tests/conftest.py index b001894f97918..176cc4342ee17 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -172,13 +172,13 @@ def single_process_pg(): def pytest_collection_modifyitems(items): - if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") != "1": + if os.getenv("PL_RUN_STANDALONE_TESTS", "0") != "1": return - # filter out non-special tests + # filter out non-standalone tests items[:] = [ item for item in items for marker in item.own_markers - # has `@RunIf(special=True)` - if marker.name == "skipif" and marker.kwargs.get("special") + # has `@RunIf(standalone=True)` + if marker.name == "skipif" and marker.kwargs.get("standalone") ] diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0d2e2a261e775..e506fc2927f7e 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -480,7 +480,7 @@ def test_result_collection_reload_1_gpu_ddp(tmpdir): result_collection_reload(default_root_dir=tmpdir, strategy="ddp", gpus=1) -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_result_collection_reload_2_gpus(tmpdir): result_collection_reload(default_root_dir=tmpdir, strategy="ddp", gpus=2) 
diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 5cdf422cf4fdb..4ad6942aa160a 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -65,7 +65,7 @@ def __new__( horovod: bool = False, horovod_nccl: bool = False, skip_windows: bool = False, - special: bool = False, + standalone: bool = False, fairscale: bool = False, fairscale_fully_sharded: bool = False, deepspeed: bool = False, @@ -87,7 +87,7 @@ def __new__( horovod: if Horovod is installed horovod_nccl: if Horovod is installed with NCCL support skip_windows: skip test for Windows platform (typically for some limited torch functionality) - special: running in special mode, outside pytest suit + standalone: Mark the test as standalone, our CI will run it in a separate process. fairscale: if `fairscale` module is required to run the test fairscale_fully_sharded: if `fairscale` fully sharded module is required to run the test deepspeed: if `deepspeed` module is required to run the test @@ -146,12 +146,12 @@ def __new__( conditions.append(not _HOROVOD_NCCL_AVAILABLE) reasons.append("Horovod with NCCL") - if special: - env_flag = os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") + if standalone: + env_flag = os.getenv("PL_RUN_STANDALONE_TESTS", "0") conditions.append(env_flag != "1") - reasons.append("Special execution") + reasons.append("Standalone execution") # used in tests/conftest.py::pytest_collection_modifyitems - kwargs["special"] = True + kwargs["standalone"] = True if fairscale: conditions.append(not _FAIRSCALE_AVAILABLE) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 663001d08df54..1e8bf40e83319 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -380,7 +380,7 @@ def test_autocast(): lite._precision_plugin.forward_context().__exit__.assert_called() -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_multiple_models(): class Lite(LightningLite): def run(self): diff --git a/tests/lite/test_parity.py b/tests/lite/test_parity.py index bec9339ec8e2f..d4d0ca6e5e9c7 100644 --- a/tests/lite/test_parity.py +++ b/tests/lite/test_parity.py @@ -190,7 +190,7 @@ def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator, assert torch.equal(w_pure.cpu(), w_lite.cpu()) -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) @pytest.mark.parametrize( "precision, strategy, devices, accelerator", [ diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 35b50acfcef4f..e8db816ed4edc 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -167,7 +167,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2) -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) def test_transfer_batch_hook_ddp(tmpdir): """Test custom data are properly moved to the right device using ddp.""" @@ -426,7 +426,7 @@ def _predict_batch(trainer, model, batches): return out -@RunIf(deepspeed=True, min_gpus=1, special=True) +@RunIf(deepspeed=True, min_gpus=1, standalone=True) @pytest.mark.parametrize("automatic_optimization", (True, False)) def test_trainer_model_hook_system_fit_deepspeed(tmpdir, automatic_optimization): _run_trainer_model_hook_system_fit( diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py index 67880bec4e474..86c4a5af68b91 100644 --- a/tests/models/test_sync_batchnorm.py +++ 
b/tests/models/test_sync_batchnorm.py @@ -67,7 +67,7 @@ def configure_optimizers(self): # TODO: Fatal Python error: Bus error @pytest.mark.skip(reason="Fatal Python error: Bus error") -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) def test_sync_batchnorm_ddp(tmpdir): seed_everything(234) set_random_main_port() diff --git a/tests/plugins/environments/torch_elastic_deadlock.py b/tests/plugins/environments/torch_elastic_deadlock.py index ead433200c304..f8a64ba632991 100644 --- a/tests/plugins/environments/torch_elastic_deadlock.py +++ b/tests/plugins/environments/torch_elastic_deadlock.py @@ -7,7 +7,7 @@ from pytorch_lightning.utilities.exceptions import DeadlockDetectedException from tests.helpers.boring_model import BoringModel -if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1" and os.getenv("PL_RECONCILE_PROCESS", "0") == "1": +if os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1" and os.getenv("PL_RECONCILE_PROCESS", "0") == "1": class CustomException(Exception): pass diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index 8f563f0e410e2..24c04de6604ef 100644 --- a/tests/plugins/test_amp_plugins.py +++ b/tests/plugins/test_amp_plugins.py @@ -190,7 +190,7 @@ def configure_optimizers(self): trainer.fit(model) -@RunIf(min_gpus=2, amp_apex=True, special=True) +@RunIf(min_gpus=2, amp_apex=True, standalone=True) @pytest.mark.parametrize("amp_level", ["O2"]) def test_amp_apex_ddp_fit(amp_level, tmpdir): class CustomBoringModel(BoringModel): diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index c0fab297173e7..6967ea9a12bd7 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -89,7 +89,7 @@ def _assert_layer_fsdp_instance(self) -> None: assert self.layer.module[2].reshard_after_forward is True -@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True, special=True) +@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True, standalone=True) def test_fully_sharded_plugin_checkpoint(tmpdir): """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" @@ -98,7 +98,7 @@ def test_fully_sharded_plugin_checkpoint(tmpdir): _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) -@RunIf(min_gpus=2, skip_windows=True, fairscale_fully_sharded=True, special=True) +@RunIf(min_gpus=2, skip_windows=True, fairscale_fully_sharded=True, standalone=True) def test_fully_sharded_plugin_checkpoint_multi_gpus(tmpdir): """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run.""" @@ -136,7 +136,7 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): trainer.test(ckpt_path=model_path) -@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True, special=True) +@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True, standalone=True) def test_fsdp_gradient_clipping_raises(tmpdir): """Test to ensure that an exception is raised when clipping gradients by value with FSDP.""" model = BoringModel() diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index 78ae931330307..1aaf89d052686 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -33,7 +33,7 @@ def on_train_start(self) -> None: self.start_cuda_memory = torch.cuda.memory_allocated() 
-@RunIf(skip_windows=True, min_gpus=2, special=True) +@RunIf(skip_windows=True, min_gpus=2, standalone=True) def test_ddp_with_2_gpus(): """Tests if device is set correctely when training and after teardown for DDPPlugin.""" trainer = Trainer(gpus=2, strategy="ddp", fast_dev_run=True) @@ -64,7 +64,7 @@ def on_train_start(self): self.trainer.training_type_plugin.barrier("barrier after model is wrapped") -@RunIf(min_gpus=4, special=True) +@RunIf(min_gpus=4, standalone=True) @mock.patch("torch.distributed.barrier") def test_ddp_barrier_non_consecutive_device_ids(barrier_mock, tmpdir): """Test correct usage of barriers when device ids do not start at 0 or are not consecutive.""" diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index efcb089487c5b..7ee46fe0c52c3 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -26,7 +26,7 @@ import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD -@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True) +@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, standalone=True) def test_ddp_fp16_compress_comm_hook(tmpdir): """Test for DDP FP16 compress hook.""" model = BoringModel() @@ -46,7 +46,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir): assert trainer.state.finished, f"Training failed with {trainer.state}" -@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True) +@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, standalone=True) def test_ddp_sgd_comm_hook(tmpdir): """Test for DDP FP16 compress hook.""" model = BoringModel() @@ -69,7 +69,7 @@ def test_ddp_sgd_comm_hook(tmpdir): assert trainer.state.finished, f"Training failed with {trainer.state}" -@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True) +@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, standalone=True) def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): """Test for DDP FP16 compress wrapper for SGD hook.""" model = BoringModel() @@ -93,7 +93,7 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): assert trainer.state.finished, f"Training failed with {trainer.state}" -@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True) +@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, standalone=True) def test_ddp_spawn_fp16_compress_comm_hook(tmpdir): """Test for DDP Spawn FP16 compress hook.""" model = BoringModel() @@ -110,7 +110,7 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir): assert trainer.state.finished, f"Training failed with {trainer.state}" -@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, special=True) +@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True) def test_ddp_post_local_sgd_comm_hook(tmpdir): """Test for DDP post-localSGD hook.""" model = BoringModel() diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 397803e1d8a17..7cca6f6724656 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -203,7 +203,7 @@ def test_deepspeed_defaults(tmpdir): assert isinstance(plugin.config["zero_optimization"], dict) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_warn_deepspeed_ignored(tmpdir): class TestModel(BoringModel): def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: @@ -259,7 +259,7 @@ def setup(self, 
trainer, pl_module, stage: Optional[str] = None) -> None: trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_run_configure_optimizers(tmpdir): """Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using configure_optimizers for optimizers and schedulers.""" @@ -296,7 +296,7 @@ def configure_optimizers(self): _assert_save_model_is_equal(model, tmpdir, trainer) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_config(tmpdir, deepspeed_zero_config): """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers and saves the model weights to load correctly.""" @@ -324,7 +324,7 @@ def on_train_start(self, trainer, pl_module) -> None: trainer.test(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_custom_precision_params(tmpdir): """Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes.""" @@ -386,7 +386,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module) -> None: trainer.fit(model) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_multigpu(tmpdir): """Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized correctly.""" @@ -402,14 +402,14 @@ def test_deepspeed_multigpu(tmpdir): _assert_save_model_is_equal(model, tmpdir, trainer) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_fp32_works(tmpdir): model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, gpus=1, strategy="deepspeed_stage_3", fast_dev_run=True) trainer.fit(model) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_stage_3_save_warning(tmpdir): """Test to ensure that DeepSpeed Stage 3 gives a warning when saving on rank zero.""" model = BoringModel() @@ -429,7 +429,7 @@ def test_deepspeed_stage_3_save_warning(tmpdir): trainer.save_checkpoint(checkpoint_path) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_multigpu_single_file(tmpdir): """Test to ensure that DeepSpeed loads from a single file checkpoint.""" model = BoringModel() @@ -538,7 +538,7 @@ def training_step(self, batch, batch_idx): opt.step() -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): """Test to ensure ZeRO Stage 3 works with a parallel model.""" model = ModelParallelBoringModel() @@ -551,7 +551,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): _assert_save_model_is_equal(model, tmpdir, trainer) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config): """Test to ensure ZeRO Stage 3 works with a parallel model.""" model = ModelParallelBoringModelManualOptim() @@ -600,14 +600,14 @@ def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumu assert results[0]["test_acc"] > 0.7 -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, 
deepspeed=True, standalone=True) def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): """Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, and see convergence.""" run_checkpoint_test(tmpdir) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the optimizer state and scheduler states cannot be restored.""" @@ -634,7 +634,7 @@ def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_path) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): """Test to ensure with Stage 3 and multiple GPUs that we can resume training.""" initial_model = ModelParallelClassificationModel() @@ -688,19 +688,19 @@ def on_train_batch_start( trainer.fit(model, datamodule=dm, ckpt_path=ck.best_model_path) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir): """Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, where we save the full weights to one file.""" run_checkpoint_test(tmpdir, automatic_optimization=False, accumulate_grad_batches=1) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir): _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer=False) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_multigpu_stage_2_accumulated_grad_batches_offload_optimizer(tmpdir): _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer=True) @@ -741,7 +741,7 @@ def on_train_batch_start(self, trainer, pl_module: LightningModule, batch: Any, assert verification_callback.on_train_batch_start_called -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_multigpu_test(tmpdir): """Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3.""" model = ModelParallelBoringModel() @@ -751,7 +751,7 @@ def test_deepspeed_multigpu_test(tmpdir): trainer.test(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_multigpu_partial_partition_parameters(tmpdir): """Test to ensure that a module that defines a layer inside the ``__init__`` and ``configure_sharded_model`` correctly converts all parameters to float16 when ``precision=16`` and runs successfully.""" @@ -778,7 +778,7 @@ def on_train_epoch_start(self) -> None: trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_multigpu_test_rnn(tmpdir): """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when training with certain layers which will crash with explicit partitioning.""" @@ -849,7 +849,7 @@ def _assert_save_model_is_equal(model, tmpdir, trainer): assert torch.equal(orig_param, saved_model_param) -@RunIf(min_gpus=2, deepspeed=True, 
special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_multigpu_no_schedulers(tmpdir): """Test to ensure ZeRO Stage 3 works with a parallel model and no schedulers.""" model = ModelParallelBoringModelNoSchedulers() @@ -861,7 +861,7 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir): _assert_save_model_is_equal(model, tmpdir, trainer) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_skip_backward_raises(tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -873,7 +873,7 @@ def training_step(self, batch, batch_idx): trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_setup_train_dataloader(tmpdir): """Test DeepSpeed works when setup is required to call in the DataModule.""" @@ -911,7 +911,7 @@ def test_dataloader(self): @mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_scheduler_step_count(mock_step): """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is set to step.""" @@ -919,7 +919,7 @@ def test_deepspeed_scheduler_step_count(mock_step): @mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_scheduler_step_count_epoch(mock_step): """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is set to epoch.""" @@ -954,7 +954,7 @@ def configure_optimizers(self): assert mock_step.call_count == 1 + (max_epoch * limit_train_batches) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_configure_gradient_clipping(tmpdir): """Test to ensure that a warning is raised when `LightningModule.configure_gradient_clipping` is overridden in case of deepspeed.""" @@ -975,7 +975,7 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_gradient_clip_by_value(tmpdir): """Test to ensure that an exception is raised when using `gradient_clip_algorithm='value'`.""" model = BoringModel() @@ -989,7 +989,7 @@ def test_deepspeed_gradient_clip_by_value(tmpdir): trainer.fit(model) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_different_accumulate_grad_batches_fails(tmpdir): model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, accumulate_grad_batches={1: 2}, gpus=1, strategy="deepspeed") @@ -999,7 +999,7 @@ def test_different_accumulate_grad_batches_fails(tmpdir): trainer.fit(model) -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_specific_gpu_device_id(tmpdir): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -1035,7 +1035,7 @@ def on_test_batch_start( trainer.test(model) -@RunIf(min_gpus=2, deepspeed=True, special=True, min_torch="1.10.0") +@RunIf(min_gpus=2, deepspeed=True, standalone=True, min_torch="1.10.0") def test_deepspeed_with_meta_device(tmpdir): with init_meta_context(): model = 
BoringModel() diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index e80b5d9f7621e..8a55633fb143e 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -175,7 +175,7 @@ def test_ddp_sharded_plugin_fit_ckpt_path_gpu_to_cpu(tmpdir): trainer.fit(model, ckpt_path=checkpoint_path) -@RunIf(skip_windows=True, special=True, fairscale=True) +@RunIf(skip_windows=True, standalone=True, fairscale=True) @pytest.mark.parametrize("trainer_kwargs", (dict(num_processes=2), pytest.param(dict(gpus=2), marks=RunIf(min_gpus=2)))) def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs): """Test to ensure we can use validate and test without fit.""" @@ -201,7 +201,7 @@ def training_step(self, batch, batch_idx): return {"loss": loss} -@RunIf(skip_windows=True, special=True, fairscale=True, min_gpus=2) +@RunIf(skip_windows=True, standalone=True, fairscale=True, min_gpus=2) def test_ddp_sharded_plugin_manual_optimization_spawn(tmpdir): # todo (sean): this test has been split out as running both tests using parametrize causes "Address in use" model = ManualBoringModel() @@ -209,7 +209,7 @@ def test_ddp_sharded_plugin_manual_optimization_spawn(tmpdir): trainer.fit(model) -@RunIf(skip_windows=True, special=True, fairscale=True, min_gpus=2) +@RunIf(skip_windows=True, standalone=True, fairscale=True, min_gpus=2) def test_ddp_sharded_plugin_manual_optimization(tmpdir): model = ManualBoringModel() trainer = Trainer(default_root_dir=tmpdir, strategy="ddp_sharded", fast_dev_run=2, gpus=2) diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 4d18648b6a7f1..126a9a6d1dee6 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -295,7 +295,7 @@ def test_advanced_profiler_cprofile_deepcopy(tmpdir): trainer.fit(model) -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): """Ensure that the profiler can be given to the training and default step are properly recorded.""" model = BoringModel() @@ -333,7 +333,7 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler): assert any(f"{local_rank}-validation_step" in f for f in files) -@RunIf(special=True) +@RunIf(standalone=True) @pytest.mark.parametrize("fast_dev_run", [1, 2, 3, 4, 5]) @pytest.mark.parametrize("boring_model_cls", [ManualOptimBoringModel, BoringModel]) def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir): @@ -428,7 +428,7 @@ def look_for_trace(trace_dir): assert look_for_trace(tmpdir) -@RunIf(min_gpus=1, special=True) +@RunIf(min_gpus=1, standalone=True) def test_pytorch_profiler_nested_emit_nvtx(tmpdir): """This test check emit_nvtx is correctly supported.""" profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) diff --git a/tests/special_tests.sh b/tests/standalone_tests.sh similarity index 82% rename from tests/special_tests.sh rename to tests/standalone_tests.sh index 27abaa6cc62e3..49c608d53cfa1 100755 --- a/tests/special_tests.sh +++ b/tests/standalone_tests.sh @@ -15,12 +15,12 @@ set -e # this environment variable allows special tests to run -export PL_RUNNING_SPECIAL_TESTS=1 +export PL_RUN_STANDALONE_TESTS=1 # python arguments defaults='-m coverage run --source pytorch_lightning --append -m pytest --capture=no' -# find tests marked as `@RunIf(special=True)`. 
done manually instead of with pytest because it is faster -grep_output=$(grep --recursive --word-regexp 'tests' --regexp 'special=True' --include '*.py' --exclude 'tests/conftest.py') +# find tests marked as `@RunIf(standalone=True)`. done manually instead of with pytest because it is faster +grep_output=$(grep --recursive --word-regexp 'tests' --regexp 'standalone=True' --include '*.py' --exclude 'tests/conftest.py') # file paths, remove duplicates files=$(echo "$grep_output" | cut -f1 -d: | sort | uniq) @@ -47,10 +47,10 @@ for i in "${!parametrizations_arr[@]}"; do continue fi - # SPECIAL_PATTERN allows filtering the tests to run when debugging. - # use as `SPECIAL_PATTERN="foo_bar" ./special_tests.sh` to run only those + # STANDALONE_PATTERN allows filtering the tests to run when debugging. + # use as `STANDALONE_PATTERN="foo_bar" ./standalone_tests.sh` to run only those # test with `foo_bar` in their name - if [[ $parametrization != *$SPECIAL_PATTERN* ]]; then + if [[ $parametrization != *STANDALONE_PATTERN* ]]; then report+="Skipped\t$parametrization\n" continue fi @@ -74,7 +74,7 @@ fi # TODO: enable when CI uses torch>=1.9 # test deadlock is properly handled with TorchElastic. -# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 PL_RECONCILE_PROCESS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED") +# LOGS=$(PL_RUN_STANDALONE_TESTS=1 PL_RECONCILE_PROCESS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED") # if [ -z "$LOGS" ]; then # exit 1 # fi diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 0ec61358d9408..6bfbaa9a7bcb1 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -434,7 +434,7 @@ def test_logging_sync_dist_true(tmpdir, devices): assert metrics["bar_3"] == 2 + int(use_multiple_devices) -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) def test_logging_sync_dist_true_ddp(tmpdir): """Tests to ensure that the sync_dist flag works with ddp.""" diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index ba4fe915fadb1..dbbb4d9bdffa7 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -840,7 +840,7 @@ def train_manual_optimization(tmpdir, strategy, model_cls=TesManualOptimizationD assert not torch.equal(param.cpu().data, param_copy.data) -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) def test_step_with_optimizer_closure_with_different_frequencies_ddp(tmpdir): """Tests that `step` works with optimizer_closure and different accumulated_gradient frequency.""" @@ -910,7 +910,7 @@ def dis_closure(): opt_dis.zero_grad() -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir): train_manual_optimization(tmpdir, "ddp", model_cls=TestManualOptimizationDDPModelToggleModel) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index b2d88becb1ec7..4a99b3318f06f 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ 
b/tests/trainer/optimization/test_optimizers.py @@ -537,7 +537,7 @@ def configure_optimizers(self): trainer.fit(model) -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) def test_optimizer_state_on_device(tmpdir): """Test that optimizers that create state initially at instantiation still end up with the state on the GPU.""" diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 6004d4540a85f..6416ef88fb210 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1462,7 +1462,7 @@ def test_trainer_predict_cpu(tmpdir, datamodule, enable_progress_bar): predict(tmpdir, datamodule=datamodule, enable_progress_bar=enable_progress_bar) -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) @pytest.mark.parametrize( "kwargs", [ @@ -1471,7 +1471,7 @@ def test_trainer_predict_cpu(tmpdir, datamodule, enable_progress_bar): {"strategy": "ddp", "devices": 2}, ], ) -def test_trainer_predict_special(tmpdir, kwargs): +def test_trainer_predict_standalone(tmpdir, kwargs): predict(tmpdir, accelerator="gpu", **kwargs) @@ -1899,7 +1899,7 @@ class CustomException(Exception): pass -@RunIf(min_gpus=2, special=True) +@RunIf(min_gpus=2, standalone=True) def test_ddp_terminate_when_deadlock_is_detected(tmpdir): """Test that DDP kills the remaining processes when only one rank is throwing an exception.""" diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 0ecafa347e574..b7dfd5cbc3311 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -60,7 +60,7 @@ def test_all_gather_ddp_spawn(): torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size) -@RunIf(min_gpus=2, skip_windows=True, special=True) +@RunIf(min_gpus=2, skip_windows=True, standalone=True) def test_all_gather_collection(tmpdir): class TestModel(BoringModel): @@ -111,7 +111,7 @@ def training_epoch_end(self, outputs) -> None: assert model.training_epoch_end_called -@RunIf(min_gpus=2, skip_windows=True, special=True) +@RunIf(min_gpus=2, skip_windows=True, standalone=True) def test_all_gather_sync_grads(tmpdir): class TestModel(BoringModel): diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py index e85557b4e6056..0f36ada39227d 100644 --- a/tests/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/utilities/test_deepspeed_collate_checkpoint.py @@ -22,7 +22,7 @@ from tests.helpers.runif import RunIf -@RunIf(min_gpus=2, deepspeed=True, special=True) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_collate_checkpoint(tmpdir): """Test to ensure that with DeepSpeed Stage 3 we can collate the sharded checkpoints into a single file.""" model = BoringModel() diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py index 581b949d9167f..1f386ac1ce0fe 100644 --- a/tests/utilities/test_meta.py +++ b/tests/utilities/test_meta.py @@ -31,7 +31,7 @@ def __init__(self, num_layers: int): self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)]) -@RunIf(special=True, min_torch="1.10.0") +@RunIf(standalone=True, min_torch="1.10.0") def test_init_meta_context(): with init_meta_context(): diff --git a/tests/utilities/test_warnings.py b/tests/utilities/test_warnings.py index 6ef3793b5e0f3..af63bc905bce3 100644 --- a/tests/utilities/test_warnings.py +++ b/tests/utilities/test_warnings.py @@ -21,8 +21,8 @@ from 
pytorch_lightning.utilities.warnings import _warn, rank_zero_deprecation, rank_zero_warn, WarningCache -running_special = os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1" -if running_special: +standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1" +if standalone: stderr = StringIO() # recording From ae53562c97d85f5410e10dbd98bb94e14193e426 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 26 Nov 2021 18:49:00 +0100 Subject: [PATCH 45/59] Remove dead code in `TrainingEpochLoop` (#10750) --- pytorch_lightning/core/hooks.py | 2 +- pytorch_lightning/loops/epoch/training_epoch_loop.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 5dfe3e986e14a..10fcb44ca19a0 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -79,7 +79,7 @@ def on_pretrain_routine_end(self) -> None: - training_start """ - def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> None: + def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]: """Called in the training loop before anything happens for that batch. If you return -1 here, you will skip training for the rest of the current epoch. diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index a75ad470c29ef..3a6049b7b23c0 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -169,10 +169,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: batch_output = [] else: # hook - response = self.trainer.call_hook("on_batch_start") - if response == -1: - self.batch_progress.increment_processed() - raise StopIteration + self.trainer.call_hook("on_batch_start") # TODO: Update this in v1.7 (deprecation: #9816) model_fx = self.trainer.lightning_module.on_train_batch_start From 31bb6e69caee5f4140dee63d0fdba1d1b85de68a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 26 Nov 2021 19:00:18 +0100 Subject: [PATCH 46/59] Avoid optional instances in Loops (#10735) * Avoid optional instances in Loops * More cleanup --- pytorch_lightning/loops/batch/training_batch_loop.py | 2 +- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 7 ++----- pytorch_lightning/loops/epoch/training_epoch_loop.py | 8 +++----- pytorch_lightning/loops/fit_loop.py | 6 +----- pytorch_lightning/loops/optimization/optimizer_loop.py | 3 --- tests/loops/test_loops.py | 2 +- 6 files changed, 8 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index c1d800c42d853..7ed199e56be13 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -48,7 +48,7 @@ def done(self) -> bool: return len(self._remaining_splits) == 0 def connect( - self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None + self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None ) -> None: if optimizer_loop is not None: self.optimizer_loop = optimizer_loop diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index b7bfc1e0ed8a2..ab9be34a0d49a 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ 
b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -49,16 +49,13 @@ def __init__(self) -> None: self._num_dataloaders: Optional[int] = None self._dataloader_iter: Optional[Iterator] = None self._data_fetcher: Optional[DataFetcher] = None - self._dataloader_state_dict: Dict[str, Any] = None + self._dataloader_state_dict: Dict[str, Any] = {} @property def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" return self.batch_progress.current.completed >= self._dl_max_batches - def connect(self, **kwargs: "Loop") -> None: - raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") - def reset(self) -> None: """Resets the loop's internal state.""" self._dl_max_batches = None @@ -192,7 +189,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher): if not self.trainer.sanity_checking and self._dataloader_state_dict: _reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict) - self._dataloader_state_dict = None + self._dataloader_state_dict = {} def _num_completed_batches_reached(self) -> bool: epoch_finished_on_completed = self.batch_progress.current.completed == self._dl_max_batches diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 3a6049b7b23c0..2a471ab198d1d 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -62,8 +62,8 @@ def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None: self.batch_progress = BatchProgress() self.scheduler_progress = SchedulerProgress() - self.batch_loop: Optional[TrainingBatchLoop] = None - self.val_loop: Optional["loops.EvaluationLoop"] = None + self.batch_loop = TrainingBatchLoop() + self.val_loop = loops.EvaluationLoop() self._results = ResultCollection(training=True) self._outputs: _OUTPUTS_TYPE = [] @@ -107,7 +107,7 @@ def done(self) -> bool: def connect( self, - batch_loop: TrainingBatchLoop = None, + batch_loop: Optional[TrainingBatchLoop] = None, val_loop: Optional["loops.EvaluationLoop"] = None, ) -> None: """Optionally connect a custom batch or validation loop to this training epoch loop.""" @@ -118,8 +118,6 @@ def connect( def reset(self) -> None: """Resets the internal state of the loop for a new run.""" - assert self.batch_loop is not None - assert self.batch_loop.optimizer_loop is not None if self.restarting: self.batch_progress.reset_on_restart() self.scheduler_progress.reset_on_restart() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index df6634c963851..4040d08d4f3dd 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -48,7 +48,7 @@ def __init__( self.max_epochs = max_epochs self.min_epochs = min_epochs - self.epoch_loop: Optional[TrainingEpochLoop] = None + self.epoch_loop = TrainingEpochLoop() self.epoch_progress = Progress() self._is_fresh_start_epoch: bool = True @@ -128,15 +128,11 @@ def running_loss(self) -> TensorRunningAccum: @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" - assert self.epoch_loop.batch_loop is not None - assert self.epoch_loop.batch_loop.optimizer_loop is not None return self.epoch_loop.batch_loop.optimizer_loop._skip_backward @_skip_backward.setter def _skip_backward(self, value: bool) -> 
None: """Determines whether the loop will skip backward during automatic optimization.""" - assert self.epoch_loop.batch_loop is not None - assert self.epoch_loop.batch_loop.optimizer_loop is not None self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value @property diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index b6bc1c3c25bf9..cdb1317e3ec3a 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -268,9 +268,6 @@ def _run_optimization( # if no result, user decided to skip optimization # otherwise update running loss + reset accumulated loss # TODO: find proper way to handle updating running loss - assert self.trainer.fit_loop is not None - assert self.trainer.fit_loop.epoch_loop is not None - assert self.trainer.fit_loop.epoch_loop.batch_loop is not None self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss) # untoggle model params diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 3c8912e145305..6338ed00e481d 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -61,7 +61,7 @@ def test_connect_loops_direct(loop_name): trainer = Trainer() - # trainer.loop = loop + # trainer.loop_name = loop setattr(trainer, loop_name, loop) assert loop.trainer is trainer From 78face65e80a1d0be4296c5025de56770b183660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 26 Nov 2021 19:12:21 +0100 Subject: [PATCH 47/59] Improve typing for logging (#10748) Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pyproject.toml | 1 - .../core/mixins/device_dtype_mixin.py | 24 +++++++++++++------ .../connectors/logger_connector/result.py | 17 +++++++------ requirements.txt | 2 +- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c266e0684e974..2471be131c41a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,6 @@ module = [ "pytorch_lightning.trainer.connectors.callback_connector", "pytorch_lightning.trainer.connectors.checkpoint_connector", "pytorch_lightning.trainer.connectors.data_connector", - "pytorch_lightning.trainer.connectors.logger_connector.result", "pytorch_lightning.trainer.data_loading", "pytorch_lightning.trainer.optimizers", "pytorch_lightning.trainer.supporters", diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index e8b122989cd9c..d902958b9bc40 100644 --- a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -17,6 +17,16 @@ import torch from torch.nn import Module +try: + from typing_extensions import Self +except ImportError: + # workaround for Python 3.6 and 3.7. + # see https://www.python.org/dev/peps/pep-0673/ + from typing import TypeVar + + Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin") + + import pytorch_lightning as pl @@ -47,7 +57,7 @@ def device(self) -> Union[str, torch.device]: return device - def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin": + def to(self, *args: Any, **kwargs: Any) -> Self: """Moves and/or casts the parameters and buffers. 
This can be called as @@ -110,7 +120,7 @@ def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin": self.__update_properties(device=out[0], dtype=out[1]) return super().to(*args, **kwargs) - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtypeModuleMixin": + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. @@ -127,7 +137,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtyp self.__update_properties(device=device) return super().cuda(device=device) - def cpu(self) -> "DeviceDtypeModuleMixin": + def cpu(self) -> Self: """Moves all model parameters and buffers to the CPU. Returns: @@ -136,7 +146,7 @@ def cpu(self) -> "DeviceDtypeModuleMixin": self.__update_properties(device=torch.device("cpu")) return super().cpu() - def type(self, dst_type: Union[str, torch.dtype]) -> "DeviceDtypeModuleMixin": + def type(self, dst_type: Union[str, torch.dtype]) -> Self: """Casts all parameters and buffers to :attr:`dst_type`. Arguments: @@ -148,7 +158,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> "DeviceDtypeModuleMixin": self.__update_properties(dtype=dst_type) return super().type(dst_type=dst_type) - def float(self) -> "DeviceDtypeModuleMixin": + def float(self) -> Self: """Casts all floating point parameters and buffers to ``float`` datatype. Returns: @@ -157,7 +167,7 @@ def float(self) -> "DeviceDtypeModuleMixin": self.__update_properties(dtype=torch.float) return super().float() - def double(self) -> "DeviceDtypeModuleMixin": + def double(self) -> Self: """Casts all floating point parameters and buffers to ``double`` datatype. Returns: @@ -166,7 +176,7 @@ def double(self) -> "DeviceDtypeModuleMixin": self.__update_properties(dtype=torch.double) return super().double() - def half(self) -> "DeviceDtypeModuleMixin": + def half(self) -> Self: """Casts all floating point parameters and buffers to ``half`` datatype. 
Returns: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index e10360a5fb564..1c27b75854d96 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -211,8 +211,10 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: self.add_state("value", torch.tensor(0.0), dist_reduce_fx=torch.sum) if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum) + # this is defined here only because upstream is missing the type annotation + self._forward_cache: Optional[Any] = None - def update(self, value: _IN_METRIC, batch_size: int) -> None: + def update(self, value: _IN_METRIC, batch_size: int) -> None: # type: ignore[override] if self.is_tensor: if not torch.is_floating_point(value): dtype = torch.get_default_dtype() @@ -225,16 +227,17 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None: if self.meta.on_step: self._forward_cache = self.meta.sync(value.clone()) # `clone` because `sync` is in-place - - # performance: no need to accumulate on values only logged on_step - if not self.meta.on_epoch: - self.value = self._forward_cache - return + # performance: no need to accumulate on values only logged on_step + if not self.meta.on_epoch: + self.value = self._forward_cache + return # perform accumulation with reduction if self.meta.is_mean_reduction: self.value += value.mean() * batch_size - self.cumulated_batch_size += batch_size + # `Metric.add_state` does not work well with mypy, mypy doesn't know this is a `Tensor` + # we could add an assertion, but this is a hot code path + self.cumulated_batch_size += batch_size # type: ignore[operator] elif self.meta.is_max_reduction or self.meta.is_min_reduction: self.value = self.meta.reduce_fx(self.value, value.mean()) elif self.meta.is_sum_reduction: diff --git a/requirements.txt b/requirements.txt index 34879d9290acb..94b7151d73641 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ tensorboard>=2.2.0 torchmetrics>=0.4.1 pyDeprecate==0.3.1 packaging>=17.0 -typing-extensions +typing-extensions>=4.0.0 From 3089dc3829c6456d74ab95aef06891927519eab9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 26 Nov 2021 19:39:09 +0100 Subject: [PATCH 48/59] Improve typing for loops (#10749) * Improve typing for loops * Free memory --- pyproject.toml | 4 --- pytorch_lightning/core/lightning.py | 4 +-- .../loops/dataloader/dataloader_loop.py | 2 +- .../loops/dataloader/evaluation_loop.py | 35 +++++++++---------- .../loops/dataloader/prediction_loop.py | 14 ++++---- .../loops/optimization/optimizer_loop.py | 6 ++-- 6 files changed, 31 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2471be131c41a..9d3e4fd80fa80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,14 +75,10 @@ module = [ "pytorch_lightning.loggers.wandb", "pytorch_lightning.loops.base", "pytorch_lightning.loops.batch.training_batch_loop", - "pytorch_lightning.loops.dataloader.dataloader_loop", - "pytorch_lightning.loops.dataloader.evaluation_loop", - "pytorch_lightning.loops.dataloader.prediction_loop", "pytorch_lightning.loops.epoch.evaluation_epoch_loop", "pytorch_lightning.loops.epoch.prediction_epoch_loop", "pytorch_lightning.loops.epoch.training_epoch_loop", "pytorch_lightning.loops.fit_loop", - "pytorch_lightning.loops.optimization.optimizer_loop", 
"pytorch_lightning.loops.utilities", "pytorch_lightning.overrides.base", "pytorch_lightning.overrides.data_parallel", diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 89f46949a525c..85afdca05dd5d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -878,7 +878,7 @@ def validation_step_end(self, val_step_outputs): See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details. """ - def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: """Called at the end of the validation epoch with the outputs of all validation steps. .. code-block:: python @@ -1056,7 +1056,7 @@ def test_step_end(self, output_results): See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details. """ - def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: """Called at the end of a test epoch with the output of all test steps. .. code-block:: python diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py index 8e0d57c782cab..7a69158ea643a 100644 --- a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -24,7 +24,7 @@ class DataLoaderLoop(Loop): """Base class to loop over all dataloaders.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.dataloader_progress = DataLoaderProgress() diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 6140bd60d6a7f..323a1ded7d01d 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, List, Optional, Sequence, Union +from typing import Any, List, Sequence, Union from deprecate.utils import void from torch.utils.data.dataloader import DataLoader @@ -26,13 +26,13 @@ class EvaluationLoop(DataLoaderLoop): """Loops over all dataloaders for evaluation.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.outputs: List[EPOCH_OUTPUT] = [] self.epoch_loop = EvaluationEpochLoop() self._results = ResultCollection(training=False) - self._max_batches: Optional[Union[int, Sequence[int]]] = None + self._max_batches: List[Union[int, float]] = [] self._has_run: bool = False @property @@ -51,11 +51,12 @@ def num_dataloaders(self) -> int: @property def dataloaders(self) -> Sequence[DataLoader]: """Returns the validation or test dataloaders.""" - if self.trainer.testing: - return self.trainer.test_dataloaders - return self.trainer.val_dataloaders + dataloaders = self.trainer.test_dataloaders if self.trainer.testing else self.trainer.val_dataloaders + if dataloaders is None: + raise RuntimeError("Dataloaders should be available.") + return dataloaders - def connect(self, epoch_loop: EvaluationEpochLoop): + def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override] """Connect the evaluation epoch loop with this loop.""" self.epoch_loop = epoch_loop @@ -117,14 +118,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: def on_run_end(self) -> List[_OUT_DICT]: """Runs the ``_on_evaluation_epoch_end`` hook.""" - outputs = self.outputs - - # free memory - self.outputs = [] - - # with a single dataloader don't pass a 2D list - if len(outputs) > 0 and self.num_dataloaders == 1: - outputs = outputs[0] + outputs, self.outputs = self.outputs, [] # free memory # lightning module method self._evaluation_epoch_end(outputs) @@ -213,7 +207,7 @@ def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: else: self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs) - def _evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None: """Runs ``{validation/test}_epoch_end``""" # inform logger the batch loop has finished self.trainer.logger_connector.epoch_end_reached() @@ -224,15 +218,20 @@ def _evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: # unset dataloader_idx in model model._current_dataloader_idx = None + # with a single dataloader don't pass a 2D list + output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = ( + outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs + ) + if self.trainer.testing: if is_overridden("test_epoch_end", model): model._current_fx_name = "test_epoch_end" - model.test_epoch_end(outputs) + model.test_epoch_end(output_or_outputs) else: if is_overridden("validation_epoch_end", model): model._current_fx_name = "validation_epoch_end" - model.validation_epoch_end(outputs) + model.validation_epoch_end(output_or_outputs) def _on_evaluation_epoch_end(self) -> None: """Runs ``on_{validation/test}_epoch_end`` hook.""" diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index cf40316312107..903fe4b26e3f0 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -13,10 +13,10 @@ class PredictionLoop(DataLoaderLoop): """Loop to run over dataloaders for prediction.""" - def __init__(self): + def __init__(self) -> None: super().__init__() - 
self.predictions: Optional[List[List[Any]]] = None - self.epoch_batch_indices: Optional[List[List[int]]] = None + self.predictions: List[List[Any]] = [] + self.epoch_batch_indices: List[List[int]] = [] self.epoch_loop = PredictionEpochLoop() self._results = None # for `trainer._results` access @@ -67,7 +67,7 @@ def dataloaders(self) -> Sequence[DataLoader]: def skip(self) -> bool: return sum(self.max_batches) == 0 - def connect(self, epoch_loop: PredictionEpochLoop): + def connect(self, epoch_loop: PredictionEpochLoop) -> None: # type: ignore[override] """Connect the prediction epoch loop with this loop.""" self.epoch_loop = epoch_loop @@ -77,7 +77,7 @@ def reset(self) -> None: self.predictions = [] self.epoch_batch_indices = [] - def on_run_start(self) -> None: + def on_run_start(self) -> None: # type: ignore[override] """Calls ``_on_predict_start`` hook.""" self._on_predict_start() @@ -94,7 +94,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.predictions.append(dl_predictions) self.epoch_batch_indices.append(dl_batch_indices) - def on_run_end(self) -> _PREDICT_OUTPUT: + def on_run_end(self) -> Optional[_PREDICT_OUTPUT]: """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders.""" results = self._on_predict_epoch_end() self._on_predict_end() @@ -135,7 +135,7 @@ def _on_predict_end(self) -> None: # hook self.trainer.call_hook("on_predict_end") - def _on_predict_model_eval(self): + def _on_predict_model_eval(self) -> None: """Calls ``on_predict_model_eval`` hook.""" model_ref = self.trainer.lightning_module model_ref.on_predict_model_eval() diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index cdb1317e3ec3a..c53b1b87a1c89 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -176,8 +176,8 @@ def __init__(self) -> None: self._outputs: _OUTPUTS_TYPE = {} self._skip_backward: bool = False self._batch_idx: int = 0 - self._optimizers: List[Optimizer] = [] - self._indices: List[int] = [] + self._optimizers: Tuple[Optimizer, ...] = tuple() + self._indices: Tuple[int, ...] 
= tuple() self._hiddens: Optional[Any] = None @property @@ -223,6 +223,8 @@ def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignor def on_run_end(self) -> _OUTPUTS_TYPE: outputs, self._outputs = self._outputs, {} # free memory + self._indices = tuple() + self._optimizers = tuple() return outputs def _run_optimization( From 88930725ddd54df224c5145abce1d734bbf701fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 26 Nov 2021 20:29:42 +0100 Subject: [PATCH 49/59] Add a custom `PossibleUserWarning` category (#10675) --- docs/source/guides/speed.rst | 5 +++++ pytorch_lightning/trainer/data_loading.py | 7 +++++-- pytorch_lightning/trainer/trainer.py | 4 +++- pytorch_lightning/utilities/warnings.py | 6 +++++- 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/docs/source/guides/speed.rst b/docs/source/guides/speed.rst index 04613a89bb35a..3b5a9baaeec9d 100644 --- a/docs/source/guides/speed.rst +++ b/docs/source/guides/speed.rst @@ -153,6 +153,11 @@ For debugging purposes or for dataloaders that load very small datasets, it is d warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") + # or to ignore all warnings which could be false positives + from pytorch_lightning.utilities.warnings import PossibleUserWarning + + warnings.filterwarnings("ignore", category=PossibleUserWarning) + Spawn """"" When using ``strategy=ddp_spawn`` or training on TPUs, the way multiple GPUs/TPU cores are used is by calling ``.spawn()`` under the hood. diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 833d0acc4a92e..bfba0229660a6 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -40,6 +40,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import PossibleUserWarning class TrainerDataLoadingMixin(ABC): @@ -109,7 +110,8 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: f"The dataloader, {name}, does not have many workers which may be a bottleneck." " Consider increasing the value of the `num_workers` argument`" f" (try {num_cpus} which is the number of cpus on this machine)" - " in the `DataLoader` init to improve performance." + " in the `DataLoader` init to improve performance.", + category=PossibleUserWarning, ) def _requires_distributed_sampler(self, dataloader) -> bool: @@ -267,7 +269,8 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - rank_zero_warn( f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval" f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if" - " you want to see logs for the training epoch." 
+ " you want to see logs for the training epoch.", + category=PossibleUserWarning, ) def _reset_eval_dataloader( diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9bdd658968b77..b417d40484028 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -96,6 +96,7 @@ LRSchedulerTypeUnion, TRAIN_DATALOADERS, ) +from pytorch_lightning.utilities.warnings import PossibleUserWarning log = logging.getLogger(__name__) # warnings to ignore in trainer @@ -1531,7 +1532,8 @@ def _log_device_info(self) -> None: if torch.cuda.is_available() and self._device_type != _AcceleratorType.GPU: rank_zero_warn( - "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`." + "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.", + category=PossibleUserWarning, ) if _TPU_AVAILABLE and self._device_type != _AcceleratorType.TPU: diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index 5a01e2a1e941d..75a8d07b01eec 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -27,8 +27,12 @@ def rank_zero_warn(*args, stacklevel: int = 4, **kwargs): _warn(*args, stacklevel=stacklevel, **kwargs) +class PossibleUserWarning(UserWarning): + """Warnings that could be false positives.""" + + class LightningDeprecationWarning(DeprecationWarning): - ... + """Deprecation warnings raised by PyTorch Lightning.""" # enable our warnings From e94aff1c5bed2c616143dc694343c71a2be0b1bd Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 26 Nov 2021 19:33:47 +0000 Subject: [PATCH 50/59] Fault Tolerant: Add support for fault tolerant dataloader validator (#10465) --- CHANGELOG.md | 3 + pytorch_lightning/trainer/data_loading.py | 3 +- pytorch_lightning/utilities/auto_restart.py | 96 ++++++++++++++++- tests/utilities/test_auto_restart.py | 112 +++++++++++++++++++- 4 files changed, 208 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f52369b443164..264e66e278b6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639)) +- Added a function to validate if fault tolerant training is supported. 
([#10465](https://github.com/PyTorchLightning/pytorch-lightning/issues/10465)) + + - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index bfba0229660a6..455c2719b124a 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -28,7 +28,7 @@ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate +from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, _validate_fault_tolerant_automatic from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, _replace_dataloader_init_method, @@ -441,6 +441,7 @@ def request_dataloader( if isinstance(dataloader, tuple): dataloader = list(dataloader) self.training_type_plugin.barrier("get_dataloaders") + _validate_fault_tolerant_automatic(dataloader, stage) return dataloader @staticmethod diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 84f0c9decefea..9d26f4a6e0736 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -16,11 +16,19 @@ from functools import partial, wraps from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union import numpy as np import torch -from torch.utils.data import Dataset, get_worker_info, Sampler +from torch.utils.data import ( + BatchSampler, + Dataset, + DistributedSampler, + get_worker_info, + RandomSampler, + Sampler, + SequentialSampler, +) from torch.utils.data.dataloader import ( _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, @@ -370,7 +378,7 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str # get current num workers num_workers = getattr(iter_dataloader, "_num_workers", 0) # as `state_dict` are workers dependent, Lightning doesn't support changing - # the `num_workers` for fault tolerant training + # the `num_workers` for Fault-tolerance if state_dict["num_workers"] != num_workers: raise MisconfigurationException( f"The provided `num_workers` {num_workers} doesn't match the one used " @@ -734,13 +742,93 @@ def _patch_dataloader_get_iterators() -> None: def _teardown_dataloader_get_iterators() -> None: """This function is used to restore the DataLoader `get_iterator` with its original one.""" - # cleanup the get_iterator replacement in case of Fault Tolerant Training. + # cleanup the get_iterator replacement in case of Fault-tolerance. 
get_iterator = getattr(DataLoader, "_ori_get_iterator", None) if get_iterator: DataLoader._get_iterator = get_iterator del DataLoader._ori_get_iterator +def _validate_iterable_dataset(dataloader: DataLoader) -> None: + SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) + + dataset = dataloader.dataset + + if getattr(dataset, "__next__", None) is None: + raise AttributeError( + "Fault-tolerance doesn't support an `IterableDataset` without `__next__` " + "method implemented. Hint: We recommend you to move your logic from `__iter__`" + " inside and rely on a sampler to perform the sample sampling." + ) + + samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)} + + if not samplers: + raise TypeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.") + + sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS] + + if not sampler: + raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.") + + if len(sampler) > 1: + raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.") + + if type(sampler[0]) is DistributedSampler and sampler.shuffle: + raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.") + elif type(sampler[0]) is not SequentialSampler: + raise TypeError("Only `SequentialSampler` is supported.") + + +def _validate_map_dataset(dataloader: DataLoader) -> None: + SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) + + sampler = getattr(dataloader, "sampler", None) + if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS: + raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.") + + batch_sampler = getattr(dataloader, "batch_sampler", None) + if batch_sampler is not None and type(batch_sampler) is not BatchSampler: + raise TypeError("Fault-tolerance supports only a `BatchSampler`.") + + if type(sampler) is DistributedSampler and sampler.shuffle: + raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.") + elif type(sampler) is RandomSampler: + raise TypeError("Only `SequentialSampler` is supported.") + + +def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None: + """This function is used to validate that Fault-tolerance is possible with the user data.""" + if not _FaultTolerantMode.detect_current_mode().is_automatic: + return + + from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator + + if isinstance(dataloader, CombinedLoader): + dataloaders = dataloader.loaders + else: + dataloaders = dataloader + + dl_loaders = [] + + def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None: + nonlocal dl_loaders + if isinstance(dataloader, CycleIterator): + dataloader = dataloader.loader + dl_loaders.append(dataloader) + + apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader) + + if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING: + raise ValueError("Fault-tolerance supports only a single dataloader.") + + for dataloader in dl_loaders: + validator_fn = ( + _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset + ) + validator_fn(dataloader) + + def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any: """This utility collects the state across processes for a collection of state.""" diff --git 
a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 1a479af05aa3f..4c2c440797dd2 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -34,10 +34,12 @@ from torch.utils.data._utils.worker import get_worker_info from torch.utils.data.dataloader import DataLoader, default_collate from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.sampler import Sampler import tests.helpers.utils as tutils from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, _collect_states_on_rank_zero_over_collection, @@ -48,6 +50,7 @@ _SingleProcessDataLoaderIterStateful, _SupportsStateDict, _teardown_dataloader_get_iterators, + _validate_fault_tolerant_automatic, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -665,6 +668,7 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w return dataset +@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None) @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"]) def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir): """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled.""" @@ -893,6 +897,10 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_ return model.seen_batches, model.parameters() +# this test will fail `fault_tolerant` don't support multiple datasets. +# this tests works as the dataset is fully deterministic and therefore +# there is not overall between the seeds. 
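[Editor's note] As a concrete illustration of what `_validate_iterable_dataset` above accepts, here is a hedged sketch of an `IterableDataset` written the way the validator expects: iteration driven by `__next__` and a supported sampler kept as an instance attribute. The class name and sizes are invented for this example:

    import torch
    from torch.utils.data import DataLoader, IterableDataset
    from torch.utils.data.sampler import SequentialSampler


    class RestartableIterable(IterableDataset):
        """Iterates via ``__next__`` and exposes its sampler as an attribute."""

        def __init__(self, length: int = 10) -> None:
            super().__init__()
            self.data = torch.arange(length)
            # kept as an attribute so the fault-tolerance validator can discover it
            self.sampler = SequentialSampler(range(length))

        def __iter__(self):
            self.sampler_iter = iter(self.sampler)
            return self

        def __next__(self) -> torch.Tensor:
            index = next(self.sampler_iter)  # StopIteration ends the epoch
            return self.data[index]


    loader = DataLoader(RestartableIterable(), batch_size=4)
    print(next(iter(loader)))  # tensor([0, 1, 2, 3])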
+@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.parametrize( "dataset_classes", @@ -1180,6 +1188,108 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_validate_fault_tolerant(tmpdir): + def data(): + return range(10) + + def dataloader(): + return DataLoader(data()) + + _validate_fault_tolerant_automatic(dataloader(), RunningStage.TRAINING) + + dataloaders = CombinedLoader([dataloader(), dataloader()]) + with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + + dataloaders = CombinedLoader([dataloader(), dataloader()], mode="max_size_cycle") + with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + + dataloaders = [dataloader(), dataloader()] + with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + + _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) + + dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=True))] + with pytest.raises(TypeError, match="A `DistributedSampler` sampler shuffle attribute is set to True."): + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + + dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=False))] + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + + dataset = SequentialGetItemDataset(2) + dataloaders = [ + DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), + DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), + ] + with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + + dataloaders = [ + DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)), + DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), + ] + with pytest.raises(ValueError, match="Fault-tolerance supports only a single."): + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + + dataloaders = [ + DataLoader(dataset, sampler=RandomSampler(dataset)), + DataLoader(dataset, sampler=SequentialSampler(dataset)), + ] + + with pytest.raises(TypeError, match="Only `SequentialSampler` is supported."): + _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) + + class CustomRandomSampler(RandomSampler): + pass + + dl = DataLoader(data(), sampler=CustomRandomSampler(data())) + with pytest.raises(TypeError, match="RandomSampler"): + _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) + + class CustomBatchSampler(BatchSampler): + pass + + sampler = Sampler(data()) + batch_sampler = CustomBatchSampler(sampler, 2, False) + dl = DataLoader(data(), batch_sampler=batch_sampler) + with pytest.raises(TypeError, match="BatchSampler"): + _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) + + class 
CustomIterable(IterableDataset): + pass + + iterable_dataloader = DataLoader(CustomIterable()) + with pytest.raises(AttributeError, match="without `__next__` method"): + _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING) + + class CustomIterable(IterableDataset): + def __next__(self): + return torch.tensor(0) + + iterable_dataloader = DataLoader(CustomIterable()) + with pytest.raises(TypeError, match="IterableDataset without a sampler as attribute"): + _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING) + + class CustomIterable(IterableDataset): + def __init__(self): + super().__init__() + self.sampler = CustomRandomSampler(data()) + + def __next__(self): + return torch.tensor(0) + + iterable_dataloader = DataLoader(CustomIterable()) + with pytest.raises(TypeError, match="RandomSampler"): + _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING) + + dataloaders = [iterable_dataloader, DataLoader(CustomIterable())] + with pytest.raises(TypeError, match="RandomSampler"): + _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) + + def test_rotate_worker_indices(): """This test ensures `worker_id` are rotated properly depending on which one was the latest.""" state_dict = {0: 0, 1: 1} From 81a0a44d8fa82f4f12d80b23a5c88766231da8c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 26 Nov 2021 21:14:11 +0100 Subject: [PATCH 51/59] Improve typing for Lite (#10743) * improve typing in pytorch_lightning/lite * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * include lite again Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 4 +--- pytorch_lightning/lite/lite.py | 16 ++++++++++++++-- pytorch_lightning/lite/wrappers.py | 1 + .../training_type/training_type_plugin.py | 4 ++-- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9d3e4fd80fa80..168e60e1e2e81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ disable_error_code = "attr-defined" # style choices warn_no_return = "False" -# Changes mypy default to ignore all errors +# Ignore mypy errors for these files # TODO: the goal is for this to be empty [[tool.mypy.overrides]] # the list can be generated with: @@ -63,8 +63,6 @@ module = [ "pytorch_lightning.core.mixins.hparams_mixin", "pytorch_lightning.core.saving", "pytorch_lightning.distributed.dist", - "pytorch_lightning.lite.lite", - "pytorch_lightning.lite.wrappers", "pytorch_lightning.loggers.base", "pytorch_lightning.loggers.comet", "pytorch_lightning.loggers.csv_logs", diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 9073f5dd54903..0d292dba54176 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -16,7 +16,7 @@ from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, cast, Dict, Generator, List, Optional, overload, Sequence, Tuple, Union import torch import torch.nn as nn @@ -201,7 +201,7 @@ def setup_dataloaders( for dataloader in dataloaders ] dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders - return dataloaders + return dataloaders # type: ignore[return-value] def _setup_dataloader( self, dataloader: DataLoader, replace_sampler: bool = True, 
move_to_device: bool = True @@ -284,6 +284,18 @@ def autocast(self) -> Generator[None, None, None]: with self._precision_plugin.forward_context(): yield + @overload + def to_device(self, obj: nn.Module) -> nn.Module: + ... + + @overload + def to_device(self, obj: Tensor) -> Tensor: + ... + + @overload + def to_device(self, obj: Any) -> Any: + ... + def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]: """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already on that device. diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 908ba06bdb84d..202404ef7162a 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -131,6 +131,7 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: iterator = iter(self._dataloader) if self._device is None: yield from iterator + return for item in iterator: yield move_data_to_device(item, self._device) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 7010c0e878dc9..be51cc9f929a4 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch from torch import Tensor @@ -241,7 +241,7 @@ def validation_step_end(self, output): def test_step_end(self, output): return output - def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + def process_dataloader(self, dataloader: DataLoader) -> DataLoader: """Wraps the dataloader if necessary. 
Args: From 038c151b6ee0754adb66714a62bf32b0a389511f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 26 Nov 2021 21:14:58 +0100 Subject: [PATCH 52/59] Improve typing for plugins (#10742) Co-authored-by: Carlos Mocholi --- pyproject.toml | 3 --- pytorch_lightning/plugins/precision/deepspeed.py | 4 +++- pytorch_lightning/plugins/precision/native_amp.py | 12 ++++++------ .../plugins/precision/precision_plugin.py | 2 +- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 168e60e1e2e81..f219d8f509d37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,9 +86,6 @@ module = [ "pytorch_lightning.plugins.environments.lsf_environment", "pytorch_lightning.plugins.environments.slurm_environment", "pytorch_lightning.plugins.environments.torchelastic_environment", - "pytorch_lightning.plugins.precision.deepspeed", - "pytorch_lightning.plugins.precision.native_amp", - "pytorch_lightning.plugins.precision.precision_plugin", "pytorch_lightning.plugins.training_type.ddp", "pytorch_lightning.plugins.training_type.ddp2", "pytorch_lightning.plugins.training_type.ddp_spawn", diff --git a/pytorch_lightning/plugins/precision/deepspeed.py b/pytorch_lightning/plugins/precision/deepspeed.py index 46cf023fc5d32..3a6eb85769559 100644 --- a/pytorch_lightning/plugins/precision/deepspeed.py +++ b/pytorch_lightning/plugins/precision/deepspeed.py @@ -49,7 +49,9 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any deepspeed_engine: DeepSpeedEngine = model.trainer.model deepspeed_engine.backward(closure_loss, *args, **kwargs) - def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None: + def _run_backward(self, tensor: Tensor, model: Optional["DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None: + if model is None: + raise ValueError("Please provide the model as input to `backward`.") model.backward(tensor, *args, **kwargs) def optimizer_step( diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index fe4a840b5337c..f6cb28c76c867 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -25,9 +25,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TORCH_GREATER_EQUAL_1_10: - from torch import autocast + from torch import autocast as new_autocast else: - from torch.cuda.amp import autocast + from torch.cuda.amp import autocast as old_autocast class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): @@ -62,7 +62,7 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) closure_loss = self.scaler.scale(closure_loss) return super().pre_backward(model, closure_loss) - def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None: + def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None: if self.scaler is not None: tensor = self.scaler.scale(tensor) super()._run_backward(tensor, model, *args, **kwargs) @@ -93,12 +93,12 @@ def optimizer_step( self.scaler.step(optimizer, **kwargs) self.scaler.update() - def autocast_context_manager(self) -> autocast: + def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]: if _TORCH_GREATER_EQUAL_1_10: # the dtype could be automatically inferred but we need to manually set it due to a bug upstream # https://github.com/pytorch/pytorch/issues/67233 - return autocast(self.device, 
dtype=torch.bfloat16 if self.precision == "bf16" else torch.half) - return autocast() + return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half) + return old_autocast() @contextmanager def forward_context(self) -> Generator[None, None, None]: diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index c4969c9cc805f..3c02d198abd3c 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -147,7 +147,7 @@ def optimizer_step( """Hook to run the optimizer step.""" if isinstance(model, pl.LightningModule): closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) - optimizer.step(closure=closure, **kwargs) + optimizer.step(closure=closure, **kwargs) # type: ignore[call-arg] def _track_grad_norm(self, trainer: "pl.Trainer") -> None: if trainer.track_grad_norm == -1: From c752060712be2ffe2ba79d9520f7e91e240013e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 27 Nov 2021 05:54:45 +0100 Subject: [PATCH 53/59] Consolidate state when retrieving sharded state dict in Lite (#10746) Co-authored-by: thomas chaton --- CHANGELOG.md | 3 +++ pytorch_lightning/lite/wrappers.py | 7 +++++-- tests/lite/test_wrappers.py | 9 +++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 264e66e278b6e..7b8941d6bdd4d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the `{validation,test}_step` outputs getting moved to CPU with `Trainer(move_metrics_to_cpu=True)` ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631)) +- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746)) + + - Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762)) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 202404ef7162a..26a76e6ed9ccd 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
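[Editor's note] The `new_autocast`/`old_autocast` aliasing shown above keeps the typing clean while still picking the right context manager per PyTorch version. A standalone sketch of that switch, using a simplified version check instead of Lightning's `_TORCH_GREATER_EQUAL_1_10` constant:

    from contextlib import contextmanager
    from typing import Generator

    import torch

    # simplified stand-in for Lightning's version flag
    _TORCH_GREATER_EQUAL_1_10 = tuple(int(p) for p in torch.__version__.split(".")[:2]) >= (1, 10)

    if _TORCH_GREATER_EQUAL_1_10:
        from torch import autocast as new_autocast
    else:
        from torch.cuda.amp import autocast as old_autocast


    def autocast_context_manager(device: str, precision: str):
        if _TORCH_GREATER_EQUAL_1_10:
            # device-agnostic API; dtype follows the configured precision
            return new_autocast(device, dtype=torch.bfloat16 if precision == "bf16" else torch.half)
        # legacy CUDA-only API without an explicit dtype argument
        return old_autocast()


    @contextmanager
    def forward_context(device: str, precision: str) -> Generator[None, None, None]:
        with autocast_context_manager(device, precision):
            yield


    with forward_context("cpu", "bf16"):
        torch.nn.Linear(2, 2)(torch.randn(1, 2))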
-from typing import Any, Callable, Generator, Iterator, Optional, Union +from typing import Any, Callable, Dict, Generator, Iterator, Optional, Union import torch from torch import nn as nn @@ -42,7 +42,7 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: """ # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would # not want to call on destruction of the `_LiteOptimizer - self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")} + self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")} self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer self._accelerator = accelerator @@ -51,6 +51,9 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: def optimizer(self) -> Optimizer: return self._optimizer + def state_dict(self) -> Dict[str, Tensor]: + return self._accelerator.optimizer_state(self.optimizer) + def step(self, closure: Optional[Callable] = None) -> None: closure = closure or _do_nothing_closure self._accelerator.optimizer_step( diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index c271d3b3163ed..a732390e1d00a 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -142,6 +142,15 @@ def test_lite_optimizer_wraps(): assert isinstance(lite_optimizer, optimizer_cls) +def test_lite_optimizer_state_dict(): + """Test that the LiteOptimizer calls into the accelerator/strategy to collect the state.""" + optimizer = Mock() + accelerator = Mock() + lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator) + lite_optimizer.state_dict() + accelerator.optimizer_state.assert_called_with(optimizer) + + def test_lite_optimizer_steps(): """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer.""" optimizer = Mock() From 49d09aa28b4bdde7a272c3e77dc3f13613802672 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 27 Nov 2021 06:28:23 +0100 Subject: [PATCH 54/59] Update changelog after 1.5.3 release (#10744) --- CHANGELOG.md | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b8941d6bdd4d..bb5d26a1071a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -184,34 +184,33 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
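[Editor's note] Stepping back to the `_LiteOptimizer` change in the previous patch (#10746): the wrapper keeps behaving like the wrapped `Optimizer` while routing `state_dict()` and `step()` through the accelerator. A standalone sketch of that proxy trick, with invented names (`_WrappedOptimizer`, `DummyAccelerator`) rather than Lightning's actual classes:

    from typing import Any, Callable, Dict, Optional

    import torch
    from torch.optim import Optimizer


    class _WrappedOptimizer:
        def __init__(self, optimizer: Optimizer, accelerator: Any) -> None:
            # reuse the wrapped optimizer's attributes, but not the methods we intercept
            self.__dict__ = {
                k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")
            }
            # create a subclass on the fly so isinstance(wrapper, type(optimizer)) still holds
            self.__class__ = type(
                "Wrapped" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}
            )
            self._optimizer = optimizer
            self._accelerator = accelerator

        def state_dict(self) -> Dict[str, Any]:
            # e.g. a sharded strategy would consolidate optimizer state across processes here
            return self._accelerator.optimizer_state(self._optimizer)

        def step(self, closure: Optional[Callable] = None) -> None:
            self._accelerator.optimizer_step(self._optimizer, closure=closure)


    class DummyAccelerator:
        def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]:
            return optimizer.state_dict()

        def optimizer_step(self, optimizer: Optimizer, closure: Optional[Callable] = None) -> None:
            optimizer.step(closure)


    model = torch.nn.Linear(2, 2)
    wrapped = _WrappedOptimizer(torch.optim.SGD(model.parameters(), lr=0.1), DummyAccelerator())
    assert isinstance(wrapped, torch.optim.SGD)
    print(list(wrapped.state_dict()))  # ['state', 'param_groups']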
### Fixed -- When a tensor is logged with `self.log`, run its computation with the same `dtype` ([#10076](https://github.com/PyTorchLightning/pytorch-lightning/pull/10076)) +- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762)) -- Fixed `ShardedTensor` state dict hook registration to check if torch distributed is available ([#10621](https://github.com/PyTorchLightning/pytorch-lightning/pull/10621)) +- Fixed TensorBoardLogger `SummaryWriter` not close before spawning the processes ([#10777](https://github.com/PyTorchLightning/pytorch-lightning/pull/10777)) -- Fixed LigtningLite `_wrap_init` popping unexisting keys from DataLoader signature parameters ([#10613](https://github.com/PyTorchLightning/pytorch-lightning/pull/10613)) +- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746)) -- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610)) +- -- Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) +- -- Fixed `Trainer(move_metrics_to_cpu=True)` not moving the evaluation logged results to CPU ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631)) +## [1.5.3] - 2021-11-24 +### Fixed +- Fixed `ShardedTensor` state dict hook registration to check if torch distributed is available ([#10621](https://github.com/PyTorchLightning/pytorch-lightning/pull/10621)) +- Fixed an issue with `self.log` not respecting a tensor's `dtype` when applying computations ([#10076](https://github.com/PyTorchLightning/pytorch-lightning/pull/10076)) +- Fixed LigtningLite `_wrap_init` popping unexisting keys from DataLoader signature parameters ([#10613](https://github.com/PyTorchLightning/pytorch-lightning/pull/10613)) +- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610)) +- Fixed an issue that caused Lightning to extract the batch size even though it was set by the user in `LightningModule.log` ([#10408](https://github.com/PyTorchLightning/pytorch-lightning/pull/10408)) +- Fixed `Trainer(move_metrics_to_cpu=True)` not moving the evaluation logged results to CPU ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631)) - Fixed the `{validation,test}_step` outputs getting moved to CPU with `Trainer(move_metrics_to_cpu=True)` ([#10631](https://github.com/PyTorchLightning/pytorch-lightning/pull/10631)) - - -- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746)) - - -- Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762)) - - -- Fixed TensorBoardLogger `SummaryWriter` not close before spawning the processes ([#10777](https://github.com/PyTorchLightning/pytorch-lightning/pull/10777)) +- Fixed an issue with collecting logged test results with multiple dataloaders ([#10522](https://github.com/PyTorchLightning/pytorch-lightning/pull/10522)) ## [1.5.2] - 2021-11-16 From 4a341e5b24bf5a081506df403996b24d3b485420 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Sun, 28 Nov 2021 22:16:52 +0530 
Subject: [PATCH 55/59] Add remote filesystems to docs (#10752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: thomas chaton Co-authored-by: Adrian Wälchli --- docs/source/common/remote_fs.rst | 54 ++++++++++++++++++++++++++++++++ docs/source/index.rst | 1 + 2 files changed, 55 insertions(+) create mode 100644 docs/source/common/remote_fs.rst diff --git a/docs/source/common/remote_fs.rst b/docs/source/common/remote_fs.rst new file mode 100644 index 0000000000000..3ea9213943fc9 --- /dev/null +++ b/docs/source/common/remote_fs.rst @@ -0,0 +1,54 @@ +Remote filesystems +================== + +PyTorch Lightning enables working with data from a variety of filesystems, including local filesystems and several cloud storage providers +such as ``s3`` on AWS, ``gcs`` on Google Cloud, or ``adl`` on Azure. + +This applies to saving and writing checkpoints, as well as for logging. +Working with different filesystems can be accomplished by appending a protocol like "s3:/" to file paths for writing and reading data. + + +.. code-block:: python + + # `default_root_dir` is the default path used for logs and weights + trainer = Trainer(default_root_dir="s3://my_bucket/data/") + trainer.fit(model) + +You could pass custom paths to loggers for logging data. + +.. code-block:: python + + from pytorch_lightning.loggers import TensorBoardLogger + + logger = TensorBoardLogger(save_dir="s3://my_bucket/logs/") + + trainer = Trainer(logger=logger) + trainer.fit(model) + +Additionally, you could also resume training with a checkpoint stored at a remote filesystem. + +.. code-block:: python + + trainer = Trainer(default_root_dir=tmpdir, max_steps=3) + trainer.fit(model, ckpt_path="s3://my_bucket/ckpts/classifier.ckpt") + +PyTorch Lightning uses `fsspec `__ internally to handle all filesystem operations. + +The most common filesystems supported by Lightning are: + +* Local filesystem: ``file://`` - It's the default and doesn't need any protocol to be used. It's installed by default in Lightning. +* Amazon S3: ``s3://`` - Amazon S3 remote binary store, using the library `s3fs `__. Run ``pip install fsspec[s3]`` to install it. +* Google Cloud Storage: ``gcs://`` or ``gs://`` - Google Cloud Storage, using `gcsfs `__. Run ``pip install fsspec[gcs]`` to install it. +* Microsoft Azure Storage: ``adl://``, ``abfs://`` or ``az://`` - Microsoft Azure Storage, using `adlfs `__. Run ``pip install fsspec[adl]`` to install it. +* Hadoop File System: ``hdfs://`` - Hadoop Distributed File System. This uses `PyArrow `__ as the backend. Run ``pip install fsspec[hdfs]`` to install it. + +You could learn more about the available filesystems with: + +.. code-block:: python + + from fsspec.registry import known_implementations + + print(known_implementations) + + +You could also look into :doc:`CheckpointIO plugin <../advanced/checkpoint_io>` for more details on how to customize saving and loading checkpoints. 
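[Editor's note] Everything in the protocol list above goes through fsspec, so the same paths can also be used directly, for example to save and reload a checkpoint by hand. A small sketch using the local `file://` protocol; the commented S3 path is a placeholder bucket and additionally requires `pip install fsspec[s3]` plus credentials:

    import fsspec
    import torch

    path = "file:///tmp/model.ckpt"
    # path = "s3://my-bucket/ckpts/model.ckpt"  # placeholder bucket, needs s3fs + credentials

    state = {"weights": torch.nn.Linear(2, 2).state_dict()}

    with fsspec.open(path, "wb") as f:
        torch.save(state, f)

    with fsspec.open(path, "rb") as f:
        restored = torch.load(f)

    print(restored["weights"].keys())  # odict_keys(['weight', 'bias'])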
diff --git a/docs/source/index.rst b/docs/source/index.rst index c1b20b958591b..82ca312ae6885 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -67,6 +67,7 @@ PyTorch Lightning common/optimizers advanced/profiler advanced/plugins_registry + common/remote_fs advanced/sequences common/single_gpu advanced/training_tricks From 3f915aaaf949f240febc7823097142ff615f7f1b Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Sun, 28 Nov 2021 22:55:03 +0530 Subject: [PATCH 56/59] Fix reference link to `s3fs` (#10737) --- docs/source/common/lightning_cli.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 7b4680b2d298c..25e35d6623b92 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -290,7 +290,7 @@ Groups of options can also be given as independent config files: When running experiments in clusters it could be desired to use a config which needs to be accessed from a remote location. :class:`~pytorch_lightning.utilities.cli.LightningCLI` comes with `fsspec `_ support which allows reading and writing from many types of remote -file systems. One example is if you have installed the `gcsfs `_ then a config +file systems. One example is if you have installed `s3fs `_ then a config could be stored in an S3 bucket and accessed as: .. code-block:: bash From e1bf54c94489e471205f4cfbe9c866100a5e4817 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sun, 28 Nov 2021 18:58:03 +0100 Subject: [PATCH 57/59] Tune Conda CI timeout and other minor improvements (#10769) --- .github/workflows/ci_test-base.yml | 1 - .github/workflows/ci_test-conda.yml | 8 ++++---- requirements/adjust_versions.py | 7 ++++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index 8b2f8b721a37e..c2f1d370e2d1a 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -59,7 +59,6 @@ jobs: - name: Test Package [only] run: | - # NOTE: run coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003 coverage run --source pytorch_lightning -m pytest pytorch_lightning -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Upload pytest test results diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index bc75921dc61fc..fa366e645f1d9 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -17,7 +17,7 @@ jobs: python-version: ["3.8"] # previous to last Python version as that one is already used in test-full pytorch-version: ["1.7", "1.8", "1.9", "1.10"] # nightly: add when there's a release candidate - timeout-minutes: 35 + timeout-minutes: 30 steps: - uses: actions/checkout@v2 @@ -29,7 +29,8 @@ jobs: python ./requirements/adjust_versions.py requirements/extra.txt python ./requirements/adjust_versions.py requirements/examples.txt pip install --requirement requirements/devel.txt --find-links https://download.pytorch.org/whl/nightly/torch_nightly.html - pip install pytest-random-order + # set a per-test timeout of 2.5 minutes to fail sooner. 
this aids with hanging tests + pip install pytest-timeout pip list - name: Pull checkpoints from S3 @@ -42,8 +43,7 @@ jobs: - name: Tests run: | - # NOTE: run coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003 - coverage run --source pytorch_lightning -m pytest --random-order-seed=2 pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml + coverage run --source pytorch_lightning -m pytest --timeout 150 pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-torch${{ matrix.pytorch-version }}.xml shell: bash -l {0} - name: Upload pytest results diff --git a/requirements/adjust_versions.py b/requirements/adjust_versions.py index 8295a726e7873..2ec7a177e0824 100644 --- a/requirements/adjust_versions.py +++ b/requirements/adjust_versions.py @@ -83,8 +83,9 @@ def test(): else: requirements_path, torch_version = sys.argv[1], None - with open(requirements_path, "r+") as fp: + with open(requirements_path) as fp: requirements = fp.read() - requirements = main(requirements, torch_version) - print(requirements) # on purpose - to debug + requirements = main(requirements, torch_version) + print(requirements) # on purpose - to debug + with open(requirements_path, "w") as fp: fp.write(requirements) From 724a92b065e3578644adee3aa84d6fcadd5404d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sun, 28 Nov 2021 21:09:30 +0100 Subject: [PATCH 58/59] Mark outputs as protected in the evaluation loops (#10781) Co-authored-by: Rohit Gupta --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 8 ++++---- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 10 ++++------ tests/callbacks/test_callback_hook_outputs.py | 2 +- tests/loops/test_evaluation_loop.py | 2 +- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 323a1ded7d01d..969a038776f94 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -28,10 +28,10 @@ class EvaluationLoop(DataLoaderLoop): def __init__(self) -> None: super().__init__() - self.outputs: List[EPOCH_OUTPUT] = [] self.epoch_loop = EvaluationEpochLoop() self._results = ResultCollection(training=False) + self._outputs: List[EPOCH_OUTPUT] = [] self._max_batches: List[Union[int, float]] = [] self._has_run: bool = False @@ -75,7 +75,7 @@ def reset(self) -> None: """Resets the internal state of the loop.""" self._max_batches = self._get_max_batches() # bookkeeping - self.outputs = [] + self._outputs = [] if isinstance(self._max_batches, int): self._max_batches = [self._max_batches] * len(self.dataloaders) @@ -110,7 +110,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders) # store batch level output per dataloader - self.outputs.append(dl_outputs) + self._outputs.append(dl_outputs) if not self.trainer.sanity_checking: # indicate the loop has run @@ -118,7 +118,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: def on_run_end(self) -> List[_OUT_DICT]: """Runs the ``_on_evaluation_epoch_end`` hook.""" - outputs, self.outputs = self.outputs, [] # free memory + outputs, self._outputs = self._outputs, [] # free memory # lightning module method self._evaluation_epoch_end(outputs) diff --git 
a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index ab9be34a0d49a..cbaac51ff1d58 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -42,9 +42,9 @@ class EvaluationEpochLoop(Loop): def __init__(self) -> None: super().__init__() - self.outputs: EPOCH_OUTPUT = [] self.batch_progress = BatchProgress() + self._outputs: EPOCH_OUTPUT = [] self._dl_max_batches: Optional[int] = None self._num_dataloaders: Optional[int] = None self._dataloader_iter: Optional[Iterator] = None @@ -61,7 +61,7 @@ def reset(self) -> None: self._dl_max_batches = None self._num_dataloaders = None self._data_fetcher = None - self.outputs = [] + self._outputs = [] if not self.restarting: self.batch_progress.reset_on_run() @@ -136,7 +136,7 @@ def advance( # track epoch level outputs if self._should_track_batch_outputs_for_epoch_end() and output is not None: - self.outputs.append(output) + self._outputs.append(output) if self.trainer.move_metrics_to_cpu: # the evaluation step output is not moved as they are not considered "metrics" @@ -149,9 +149,7 @@ def advance( def on_run_end(self) -> EPOCH_OUTPUT: """Returns the outputs of the whole run.""" - outputs = self.outputs - # free memory - self.outputs = [] + outputs, self._outputs = self._outputs, [] # free memory self._dataloader_iter = None self._data_fetcher = None return outputs diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index d55313fde37e6..f7c9321cd0e2e 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -64,7 +64,7 @@ def training_epoch_end(self, outputs) -> None: def test_free_memory_on_eval_outputs(tmpdir): class CB(Callback): def on_epoch_end(self, trainer, pl_module): - assert len(trainer._evaluation_loop.outputs) == 0 + assert not trainer._evaluation_loop._outputs model = BoringModel() diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index d6b2c15553fb9..d553d386205f5 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -116,7 +116,7 @@ def on_test_batch_end(self, outputs, *_): class TestLoop(EvaluationEpochLoop): def on_advance_end(self): # should be empty - assert not self.outputs + assert not self._outputs # sanity check nonlocal did_assert did_assert = True From 97e52619ea753aeec0b37acedd7568182242f8e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Nov 2021 10:58:23 +0100 Subject: [PATCH 59/59] Fix typing in `pl.overrides.data_parallel` (#10796) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 2 -- pytorch_lightning/overrides/base.py | 12 +++++++++- pytorch_lightning/overrides/data_parallel.py | 10 ++++---- .../plugins/training_type/parallel.py | 4 ++-- .../plugins/training_type/sharded.py | 4 ++-- .../plugins/training_type/sharded_spawn.py | 4 ++-- .../training_type/training_type_plugin.py | 4 ++-- tests/overrides/test_data_parallel.py | 3 ++- tests/plugins/test_sharded_plugin.py | 24 ++++++++++--------- 9 files changed, 39 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f219d8f509d37..0e56d3a3dbd04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,8 +78,6 @@ module = [ "pytorch_lightning.loops.epoch.training_epoch_loop", "pytorch_lightning.loops.fit_loop", 
"pytorch_lightning.loops.utilities", - "pytorch_lightning.overrides.base", - "pytorch_lightning.overrides.data_parallel", "pytorch_lightning.overrides.distributed", "pytorch_lightning.overrides.fairscale", "pytorch_lightning.plugins.environments.lightning_environment", diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index fc22902495820..d75628c1a81f2 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -14,6 +14,7 @@ from typing import Any, Union import torch +import torch.nn as nn from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel @@ -101,10 +102,19 @@ def on_post_move_to_device(self) -> None: pass -def unwrap_lightning_module(wrapped_model) -> "pl.LightningModule": +def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule": + """Recursively unwraps a :class:`~pytorch_lightning.core.lightning.LightningModule` by following the + ``.module`` attributes on the wrapper. + + Raises: + TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped + further. + """ model = wrapped_model if isinstance(model, (DistributedDataParallel, DataParallel)): model = unwrap_lightning_module(model.module) if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)): model = unwrap_lightning_module(model.module) + if not isinstance(model, pl.LightningModule): + raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.") return model diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 615f2c04e73d8..fd32619ed818f 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -13,7 +13,7 @@ # limitations under the License. import numbers import warnings -from typing import Any +from typing import Any, Union import torch @@ -23,7 +23,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection -def _ignore_scalar_return_in_dp(): +def _ignore_scalar_return_in_dp() -> None: # Users get confused by this warning so we silence it warnings.filterwarnings( "ignore", @@ -57,12 +57,12 @@ def __init__(self, pl_module: "pl.LightningModule") -> None: super().__init__(pl_module) _ignore_scalar_return_in_dp() - def forward(self, *inputs, **kwargs): + def forward(self, *inputs: Any, **kwargs: Any) -> Any: self.update_replica_device_attributes(inputs) # forward call will redirect to training_step, validation_step, etc. 
output = super().forward(*inputs, **kwargs) - def output_transform(data: Any): + def output_transform(data: Any) -> Any: data = python_scalar_to_tensor(data, self.module.device) data = unsqueeze_scalar_tensor(data) return data @@ -101,7 +101,7 @@ def find_tensor_with_device(tensor: torch.Tensor) -> torch.Tensor: ) -def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any: +def python_scalar_to_tensor(data: Any, device: Union[str, torch.device] = torch.device("cpu")) -> Any: """Converts a Python scalar number to a torch tensor and places it on the given device.""" if isinstance(data, numbers.Number): data = torch.tensor([data], device=device) diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 07ede1ae4f833..3a05455b87990 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -57,8 +57,8 @@ def on_tpu(self) -> bool: return self.root_device.type == "xla" and _XLA_AVAILABLE @property - def lightning_module(self): - return unwrap_lightning_module(self._model) + def lightning_module(self) -> Optional["pl.LightningModule"]: + return unwrap_lightning_module(self._model) if self._model is not None else None @property def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index e7f57e9c92791..280d38bc839f1 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -101,13 +101,13 @@ def _optim_state_dict(self, optimizer): return optimizer.state_dict() @property - def lightning_module(self) -> "pl.LightningModule": + def lightning_module(self) -> Optional["pl.LightningModule"]: if not _FAIRSCALE_AVAILABLE: # pragma: no cover raise MisconfigurationException( "`DDPShardedPlugin` requires `fairscale` to be installed." " Install it by running `pip install fairscale`." ) - return unwrap_lightning_module_sharded(self._model) + return unwrap_lightning_module_sharded(self._model) if self._model is not None else None def pre_backward(self, closure_loss: torch.Tensor) -> None: pass diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 12c06b9dde541..9f83f0261c3ec 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -101,13 +101,13 @@ def _optim_state_dict(self, optimizer): return optimizer.state_dict() @property - def lightning_module(self) -> "pl.LightningModule": + def lightning_module(self) -> Optional["pl.LightningModule"]: if not _FAIRSCALE_AVAILABLE: # pragma: no cover raise MisconfigurationException( "`DDPSpawnShardedPlugin` requires `fairscale` to be installed." " Install it by running `pip install fairscale`." 
) - return unwrap_lightning_module_sharded(self._model) + return unwrap_lightning_module_sharded(self._model) if self._model is not None else None def pre_backward(self, closure_loss: torch.Tensor) -> None: pass diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index be51cc9f929a4..b8244b9c2e165 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -177,9 +177,9 @@ def model(self, new_model: Optional[Module]) -> None: self._model = new_model @property - def lightning_module(self) -> "pl.LightningModule": + def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" - return unwrap_lightning_module(self._model) + return unwrap_lightning_module(self._model) if self._model is not None else None @property def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 62a497b3106d1..1e0003486a06b 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -42,7 +42,8 @@ ) def test_lightning_wrapper_module_methods(wrapper_class, stage): """Test that the LightningWrapper redirects .forward() to the LightningModule methods.""" - pl_module = MagicMock() + pl_module = Mock(spec=LightningModule) + pl_module.trainer = Mock() wrapped_module = wrapper_class(pl_module) batch = torch.rand(5) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 8a55633fb143e..f6b58692aa221 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -1,5 +1,6 @@ import os from unittest import mock +from unittest.mock import Mock import pytest import torch @@ -256,14 +257,14 @@ def test_configure_ddp(tmpdir): def test_custom_kwargs_sharded(tmpdir, cls): """Tests to ensure that if custom kwargs are passed, they are set correctly.""" plugin = cls(reduce_fp16=True) - + plugin.model = Mock(spec=LightningModule) + plugin.model.trainer = Mock() class_name = "sharded" if isinstance(plugin, DDPShardedPlugin) else "sharded_spawn" - with mock.patch.object(plugin, "_model", autospec=True): - with mock.patch( - f"pytorch_lightning.plugins.training_type.{class_name}.ShardedDataParallel", autospec=True - ) as mock_sharded: - plugin.configure_ddp() + with mock.patch( + f"pytorch_lightning.plugins.training_type.{class_name}.ShardedDataParallel", autospec=True + ) as mock_sharded: + plugin.configure_ddp() args, kwargs = mock_sharded.call_args assert "reduce_fp16" in kwargs assert kwargs["reduce_fp16"] @@ -277,12 +278,13 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe """Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs.""" plugin = DDPShardedPlugin(**params) plugin.num_nodes = num_nodes + plugin.model = Mock(spec=LightningModule) + plugin.model.trainer = Mock() - with mock.patch.object(plugin, "_model", autospec=True): - with mock.patch( - "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True - ) as mock_sharded: - plugin.configure_ddp() + with mock.patch( + "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True + ) as mock_sharded: + plugin.configure_ddp() args, kwargs = mock_sharded.call_args assert "reduce_buffer_size" in kwargs
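[Editor's note] To round off the `unwrap_lightning_module` change above: the helper simply follows `.module` attributes until it reaches the real module, and now raises if it ends up somewhere unexpected. A generic sketch of the same pattern with a stand-in wrapper class (not Lightning's `DistributedDataParallel`/`_LightningModuleWrapperBase` handling):

    import torch.nn as nn


    class _Wrapper(nn.Module):
        """Stand-in for DataParallel-style wrappers that expose the wrapped network as ``self.module``."""

        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)


    def unwrap(wrapped: nn.Module, target_type: type) -> nn.Module:
        model = wrapped
        while isinstance(model, _Wrapper):
            model = model.module
        if not isinstance(model, target_type):
            # mirrors the new behaviour: fail loudly instead of silently returning a wrapper
            raise TypeError(f"Unwrapping did not yield a `{target_type.__name__}`, got {type(model)} instead.")
        return model


    inner = nn.Linear(4, 4)
    assert unwrap(_Wrapper(_Wrapper(inner)), nn.Linear) is inner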