
PyTorch Lightning 1.7: Apple Silicon support, Native FSDP, Collaborative training, and multi-GPU support with Jupyter notebooks

@carmocca carmocca released this 02 Aug 16:21
· 3194 commits to master since this release
d2c086b

The core team is excited to announce the release of PyTorch Lightning 1.7 ⚡

PyTorch Lightning 1.7 is the culmination of work from 106 contributors who have worked on features, bug fixes, and documentation for a total of over 492 commits since 1.6.0.

Highlights

Apple Silicon Support

For those using PyTorch 1.12 on M1 or M2 Apple machines, we have created the MPSAccelerator. It enables GPU-accelerated training using Apple’s Metal Performance Shaders (MPS) as the backend.


NOTE

Support for this accelerator is currently marked as experimental in PyTorch. Because many operators are still missing, you may run into a few rough edges.


# Selects the accelerator
trainer = pl.Trainer(accelerator="mps")

# Equivalent to
from pytorch_lightning.accelerators import MPSAccelerator
trainer = pl.Trainer(accelerator=MPSAccelerator())

# Defaults to "mps" when run on M1 or M2 Apple machines
# to avoid code changes when switching computers
trainer = pl.Trainer(accelerator="gpu")

Native Fully Sharded Data Parallel Strategy

PyTorch 1.12 also added native support for Fully Sharded Data Parallel (FSDP). Previously, PyTorch Lightning enabled this by using the fairscale project. You can now choose between both options.


NOTE

Support for this strategy is marked as beta in PyTorch.


# Native PyTorch implementation
trainer = pl.Trainer(strategy="fsdp_native")

# Equivalent to
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
trainer = pl.Trainer(strategy=DDPFullyShardedNativeStrategy())

# For reference, FairScale's implementation can be used with
trainer = pl.Trainer(strategy="fsdp")

A Collaborative Training strategy using Hivemind

Collaborative Training removes the need for top-tier multi-GPU servers by allowing you to train across unreliable machines, such as local machines or even preemptible cloud compute, over the Internet.

Under the hood, we use Hivemind, which provides decentralized training across the Internet.

from pytorch_lightning.strategies import HivemindStrategy

trainer = pl.Trainer(
    strategy=HivemindStrategy(target_batch_size=8192), 
    accelerator="gpu", 
    devices=1
)

For more information, check out the docs.

Distributed support in Jupyter Notebooks

So far, the only multi-GPU strategy supported in Jupyter notebooks (including Grid.ai, Google Colab, and Kaggle, for example) has been the Data-Parallel (DP) strategy (strategy="dp"). DP, however, has several limitations that often obstruct users' workflows. It can be slow, it's incompatible with TorchMetrics, it doesn't persist state changes on replicas, and it's difficult to use with non-primitive input and output structures.

In this release, we've added support for Distributed Data Parallel in Jupyter notebooks using the fork mechanism to address these shortcomings. This is only available for macOS and Linux (sorry Windows!).


NOTE

This feature is experimental.


This is how you use multi-device in notebooks now:

# Train on 2 GPUs in a Jupyter notebook
trainer = pl.Trainer(accelerator="gpu", devices=2)

# Can be set explicitly
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp_notebook")

# Can also be used in non-interactive environments
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp_fork")

By default, the Trainer detects the interactive environment and selects the right strategy for you. Learn more in the full documentation.

Versioning of "last" checkpoints

If a run is configured to save to the same directory as a previous run and ModelCheckpoint(save_last=True) is enabled, the "last" checkpoint is now versioned with a simple -v1 suffix to avoid overwriting the existing "last" checkpoint. This mimics the behaviour for checkpoints that monitor a metric.
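The suffixing behaviour described above can be illustrated with a small standalone sketch. This is not Lightning's implementation; the helper name `next_last_checkpoint_path` and the `.ckpt` extension default are assumptions made for illustration only.

```python
import tempfile
from pathlib import Path


def next_last_checkpoint_path(dirpath, name="last", ext=".ckpt"):
    """Return a path for a new "last" checkpoint without overwriting an existing one.

    Hypothetical helper: tries last.ckpt, then last-v1.ckpt, last-v2.ckpt, ...
    """
    dirpath = Path(dirpath)
    candidate = dirpath / f"{name}{ext}"
    version = 1
    while candidate.exists():
        candidate = dirpath / f"{name}-v{version}{ext}"
        version += 1
    return candidate


with tempfile.TemporaryDirectory() as tmp:
    first = next_last_checkpoint_path(tmp)
    first.touch()  # simulate the checkpoint written by a previous run
    second = next_last_checkpoint_path(tmp)
    print(first.name, second.name)  # last.ckpt last-v1.ckpt
```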

Automatically reload the "last" checkpoint

In certain scenarios, like when running in a cloud spot instance with fault-tolerant training enabled, it is useful to load the latest available checkpoint. It is now possible to pass the string ckpt_path="last" in order to load the latest available checkpoint from the set of existing checkpoints.

trainer = Trainer(...)
trainer.fit(..., ckpt_path="last")
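As a rough mental model of what "latest available checkpoint" could mean, the sketch below picks the most recently modified file among candidates in a directory. The selection rule (modification time), the glob pattern, and the function name `find_latest_checkpoint` are assumptions for illustration; the release notes do not spell out how Lightning chooses.

```python
from pathlib import Path


def find_latest_checkpoint(dirpath, pattern="last*.ckpt"):
    """Return the most recently modified checkpoint matching `pattern`, or None.

    Hypothetical helper, not part of the Lightning API.
    """
    candidates = list(Path(dirpath).glob(pattern))
    if not candidates:
        return None
    # Compare by modification time so that a versioned file such as
    # last-v1.ckpt wins when it was written after last.ckpt.
    return max(candidates, key=lambda path: path.stat().st_mtime)
```

Returning None when nothing matches lets the caller decide whether to start training from scratch.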

Validation every N batches across epochs

In some cases, for example iteration-based training, it is useful to run validation after every N training batches without being limited by the epoch boundary. Now, you can enable validation based on the total number of training batches.

trainer = Trainer(..., val_check_interval=N, check_val_every_n_epoch=None)
trainer.fit(...)

For example, given 5 epochs of 10 batches, setting N=25 would run validation in the 3rd and 5th epoch.
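The arithmetic behind this example can be checked with a few lines of plain Python. The function name `validation_epochs` is made up for this sketch, and it only models the counting, not the Trainer itself.

```python
def validation_epochs(num_epochs, batches_per_epoch, val_check_interval):
    """List the (1-based) epochs in which validation would be triggered."""
    total_batches = num_epochs * batches_per_epoch
    epochs = []
    # Validation fires each time the global batch count reaches a multiple of N.
    for global_batch in range(val_check_interval, total_batches + 1, val_check_interval):
        epochs.append((global_batch - 1) // batches_per_epoch + 1)
    return epochs


print(validation_epochs(num_epochs=5, batches_per_epoch=10, val_check_interval=25))  # [3, 5]
```

Batch 25 falls inside the 3rd epoch (batches 21 to 30) and batch 50 is the last batch of the 5th epoch.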

CPU stats monitoring

PyTorch Lightning provides the DeviceStatsMonitor callback to monitor the stats of the hardware currently used. However, users often also want to monitor the stats of other hardware. In this release, we have added an option to additionally monitor CPU stats:

from pytorch_lightning.callbacks import DeviceStatsMonitor

# Log both CPU stats and GPU stats
trainer = pl.Trainer(callbacks=DeviceStatsMonitor(cpu_stats=True), accelerator="gpu")

# Log just the GPU stats
trainer = pl.Trainer(callbacks=DeviceStatsMonitor(cpu_stats=False), accelerator="gpu")

# Equivalent to `DeviceStatsMonitor()`
trainer = pl.Trainer(callbacks=DeviceStatsMonitor(cpu_stats=True), accelerator="cpu")

The CPU stats are gathered using the psutil package.

Automatic distributed samplers

It is now possible to use custom samplers in a distributed environment without the need to set replace_sampler_ddp=False and wrap your sampler manually with the DistributedSampler.
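Conceptually, wrapping a sampler for distributed training means that each process sees only its own slice of the indices the sampler yields. The sketch below shows that idea in plain Python; `shard_indices` is a hypothetical helper, and the padding rule (repeat leading indices so every rank gets the same count) is an assumption modelled on the usual behaviour of distributed samplers, not Lightning's actual wrapper.

```python
def shard_indices(indices, rank, world_size):
    """Return the subset of `indices` that process `rank` should iterate over."""
    indices = list(indices)
    # Pad by repeating indices from the start so the length divides evenly.
    pad = (-len(indices)) % world_size
    indices += (indices * world_size)[:pad]
    # Interleave: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank::world_size]


custom_order = [4, 0, 3, 1, 2]  # order produced by some custom sampler
for rank in range(2):
    print(rank, shard_indices(custom_order, rank, world_size=2))
# 0 [4, 3, 2]
# 1 [0, 1, 4]
```

The custom ordering is preserved within each shard, which is the point of wrapping rather than replacing the sampler.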

Inference mode support

PyTorch 1.9 introduced torch.inference_mode, which is a faster alternative to torch.no_grad. Lightning will now use inference_mode wherever possible during evaluation.

Support for warn-level determinism

In PyTorch 1.11, operations that do not have a deterministic implementation can be set to throw a warning instead of an error when run in deterministic mode. This is now supported by our Trainer:

trainer = pl.Trainer(deterministic="warn")

LightningCLI improvements

After the latest updates to jsonargparse, the library supporting the LightningCLI, there's now complete support for shorthand notation. This includes automatic support for shorthand notation to all arguments, not just the ones that are part of the registries, plus support inside configuration files.

+ # pytorch_lightning==1.7.0
  trainer:
    callbacks:
-     - class_path: pytorch_lightning.callbacks.EarlyStopping
+     - class_path: EarlyStopping
        init_args:
          monitor: "loss"

A header with the version that generated the config is now included.

All subclasses for a given base class can be specified by name, so there's no need to explicitly register them. The only requirement is that the module where the subclass is defined is imported prior to parsing.

from pytorch_lightning.cli import LightningCLI
import my_code.models
import my_code.optimizers

cli = LightningCLI()
# Now use any of the classes:
# python trainer.py fit --model=Model1 --optimizer=CustomOptimizer

The new version renders the registries and the auto_registry flag, introduced in 1.6.0, unnecessary, so we have deprecated them.

Support was also added for list appending; for example, to add a callback to an existing list that might be already configured:

$ python trainer.py fit \
-   --trainer.callbacks=EarlyStopping \
+   --trainer.callbacks+=EarlyStopping \
    --trainer.callbacks.patience=5 \
-   --trainer.callbacks=LearningRateMonitor \
+   --trainer.callbacks+=LearningRateMonitor \
    --trainer.callbacks.logging_interval=epoch

Callback registration through entry points

Entry Points are an advanced feature in Python's setuptools that allow packages to expose metadata to other packages. In Lightning, we allow an arbitrary package to include callbacks that the Lightning Trainer can automatically use when installed, without you having to manually add them to the Trainer. This is useful in production environments where it is common to provide specialized monitoring and logging callbacks globally for every application.

A setup.py file for a callbacks plugin package could look something like this:

from setuptools import setup

setup(
    name="my-package",
    version="0.0.1",
    entry_points={
        # Lightning will look for this key here in the environment:
        "pytorch_lightning.callbacks_factory": [
            "monitor_callbacks=factories:my_custom_callbacks_factory"
        ]
    },
)

Read more about callback entry points in our docs.
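To make the mechanism more concrete, here is a minimal sketch of how entry points in a given group can be discovered and called using only the standard library. It assumes each entry point resolves to a zero-argument factory returning a list of callbacks; the function name `load_callbacks_from_entry_points` is hypothetical, and Lightning's own discovery code may differ.

```python
from importlib.metadata import entry_points


def load_callbacks_from_entry_points(group="pytorch_lightning.callbacks_factory"):
    """Collect callbacks from every installed package that registers a factory in `group`."""
    eps = entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+: selectable entry points
        selected = eps.select(group=group)
    else:
        # Python 3.8 / 3.9: a plain dict mapping group name to entry points
        selected = eps.get(group, [])

    callbacks = []
    for entry_point in selected:
        factory = entry_point.load()  # imports e.g. factories:my_custom_callbacks_factory
        callbacks.extend(factory())   # assumed to return a list of callbacks
    return callbacks
```

With no plugin package installed, the function simply returns an empty list.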

Rank-zero only EarlyStopping messages

Our EarlyStopping callback implementation, by default, logs the stopping messages on every rank when it's run in a distributed environment. This was done in case the monitored values were not synchronized. However, some users found this verbose. To avoid this, you can now set a flag:

from pytorch_lightning.callbacks import EarlyStopping

trainer = pl.Trainer(callbacks=EarlyStopping(..., log_rank_zero_only=True))
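The effect of the flag can be summarised by a one-line rule, sketched here in plain Python. `should_log` is a hypothetical helper used only to illustrate the behaviour described above.

```python
def should_log(rank, log_rank_zero_only):
    """Decide whether the process with the given global rank emits the stopping message."""
    return rank == 0 or not log_rank_zero_only


# With 4 processes: every rank logs by default, only rank 0 logs with the flag set.
print([r for r in range(4) if should_log(r, log_rank_zero_only=False)])  # [0, 1, 2, 3]
print([r for r in range(4) if should_log(r, log_rank_zero_only=True)])   # [0]
```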

A base Checkpoint class for extra customization

If you want to write a custom checkpointing callback without inheriting all the extra functionality of ModelCheckpoint, this release provides an empty Checkpoint class for easier inheritance. All internal code checks against the Checkpoint class to ensure everything works properly for custom classes.

Validation now runs in overfitting mode

Setting overfit_batches=N now enables validation and runs N validation batches during trainer.fit.

# Uses 1% of each train & val set
trainer = Trainer(overfit_batches=0.01)

# Uses 10 batches for each train & val set
trainer = Trainer(overfit_batches=10)
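The two forms of the argument can be read as follows: a float is a fraction of the available batches, an int is an absolute count. The sketch below shows that interpretation; `resolve_num_batches` is a hypothetical helper, and the exact rounding Lightning applies is an assumption.

```python
def resolve_num_batches(overfit_batches, dataloader_len):
    """Translate an overfit_batches value into a concrete number of batches."""
    if isinstance(overfit_batches, float):
        if not 0.0 <= overfit_batches <= 1.0:
            raise ValueError("a float overfit_batches must be between 0.0 and 1.0")
        # Fraction of the dataloader, rounded down (assumed rounding)
        return int(dataloader_len * overfit_batches)
    # Absolute count, capped at what the dataloader can provide
    return min(overfit_batches, dataloader_len)


print(resolve_num_batches(0.01, dataloader_len=1000))  # 10
print(resolve_num_batches(10, dataloader_len=1000))    # 10
```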

Device Stats Monitoring support for HPUs

The DeviceStatsMonitor callback can now be used to automatically monitor and log device stats during the training stage with Habana devices.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor

device_stats = DeviceStatsMonitor()
trainer = Trainer(accelerator="hpu", callbacks=[device_stats])

New Hooks

LightningDataModule.load_from_checkpoint

Hyper-parameters from a LightningDataModule are now saved to checkpoints and reloaded when training is resumed. And just like you use LightningModule.load_from_checkpoint to load a model from a checkpoint filepath, you can now load a LightningDataModule using the same method.

# Load weights without mapping ...
datamodule = MyLightningDataModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# Or load weights and hyperparameters from separate files.
datamodule = MyLightningDataModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# Override some of the params with new values
datamodule = MyLightningDataModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    batch_size=32,
    num_workers=10,
)

Experimental Features

ServableModule and its Servable Module Validator Callback

When serving models in production, it is generally good practice to ensure that the model can be served and optimized before starting training, to avoid wasting money.

To do so, you can import a ServableModule (an nn.Module) and add it as an extra base class to your base model as follows:

from pytorch_lightning import LightningModule
from pytorch_lightning.serve import ServableModule

class ProductionReadyModel(LightningModule, ServableModule):
    ...

To make your model servable, you would need to implement three hooks:

  • configure_payload: Describe the format of the payload (data sent to the server).
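  • configure_serialization: Describe the functions used to convert the payload to tensors (de-serialization) and tensors to payload (serialization).
  • serve_step: The method used to transform the input tensors to a dictionary of prediction tensors.

import base64
from io import BytesIO
from typing import Dict

import torch
import torchvision.transforms as T
from pytorch_lightning.serve import ServableModule, ServableModuleValidator

# LitModule, Image, and Top1 are assumed to be defined elsewhere (see the full example).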

class ProductionReadyModel(LitModule, ServableModule):
    def configure_payload(self):
        # 1: Access the train dataloader and load a single sample.
        image, _ = self.trainer.train_dataloader.loaders.dataset[0]

        # 2: Convert the image into a PIL Image to bytes and encode it with base64
        pil_image = T.ToPILImage()(image)
        buffered = BytesIO()
        pil_image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("UTF-8")

        payload = {"body": {"x": img_str}}
        return payload

    def configure_serialization(self):
        deserializers = {"x": Image(224, 224).deserialize}
        serializers = {"output": Top1().serialize}
        return deserializers, serializers

    def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"output": self.model(x)}

Finally, add the ServableModuleValidator callback to the Trainer to validate that the model is servable in on_train_start. This uses a FastAPI server.

pl_module = ProductionReadyModel()
trainer = Trainer(..., callbacks=[ServableModuleValidator()])
trainer.fit(pl_module)

Have a look at the full example here.

Asynchronous Checkpointing

You can now save checkpoints asynchronously using the AsyncCheckpointIO plugin without blocking your training process. To enable this, pass an AsyncCheckpointIO plugin to the Trainer.

from pytorch_lightning.plugins.io import AsyncCheckpointIO

trainer = Trainer(plugins=[AsyncCheckpointIO()])

Have a look at the full example here.
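The general idea of non-blocking saving can be sketched with a single background worker thread. This is an illustration of the pattern only, not the plugin's code: the class name `AsyncCheckpointWriter`, the JSON format, and the `close` method are all assumptions made for this example.

```python
import json
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


class AsyncCheckpointWriter:
    """Hypothetical writer that saves state in a background thread."""

    def __init__(self):
        # One worker keeps writes ordered while the caller continues immediately.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._futures = []

    def save(self, state, path):
        # Copy the dict so later mutations by the caller do not race with the write.
        future = self._executor.submit(self._write, dict(state), Path(path))
        self._futures.append(future)

    @staticmethod
    def _write(state, path):
        path.write_text(json.dumps(state))

    def close(self):
        # Wait for pending writes and surface any exception they raised.
        for future in self._futures:
            future.result()
        self._executor.shutdown(wait=True)


with tempfile.TemporaryDirectory() as tmp:
    writer = AsyncCheckpointWriter()
    writer.save({"step": 100}, Path(tmp) / "step100.json")  # returns without waiting
    writer.close()
    print(json.loads((Path(tmp) / "step100.json").read_text()))  # {'step': 100}
```

Waiting in `close` matters: without it, a process could exit before the last checkpoint reaches disk.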

Backward Incompatible Changes

This section outlines notable changes that are not backward compatible with previous versions. The full list of changes and removals can be found in the CHANGELOG below.

Removed support for the DDP2 strategy

The DDP2 strategy, previously known as the DDP2 plugin, has been part of Lightning since its inception. Due to both the technical challenges in maintaining the plugin after PyTorch's removal of the multi-device support in DistributedDataParallel, as well as a general lack of interest, we have decided to retire the strategy entirely.

Do not force metric synchronization on epoch end

In previous versions, metrics logged inside epoch-end hooks were forcefully synced. This made the sync_dist flag irrelevant and caused communication overhead that might be undesired. In this release, we've removed this behaviour and instead warn the user that synchronization might be desired.

Deprecations

| API | Removal version | Alternative |
| --- | --- | --- |
| Import pytorch_lightning.loggers.base.LightningLoggerBase | 1.9 | pytorch_lightning.loggers.logger.Logger |
| Import pytorch_lightning.callbacks.base.Callback | 1.9 | pytorch_lightning.callbacks.callback.Callback |
| Import pytorch_lightning.core.lightning.LightningModule | 1.9 | pytorch_lightning.core.module.LightningModule |
| Import pytorch_lightning.loops.base.Loop | 1.9 | pytorch_lightning.loops.loop.Loop |
| Import pytorch_lightning.profiler | 1.9 | pytorch_lightning.profilers |
| Arguments Trainer(num_processes=..., gpus=..., tpu_cores=..., ipus=...) | 2.0 | Trainer(accelerator=..., devices=...) |
| Argument LightningCLI(seed_everything_default=None) | 1.9 | LightningCLI(seed_everything_default=False) |
| Method Trainer.reset_train_val_dataloaders() | 1.9 | Trainer.reset_{train,val}_dataloader |
| Import pytorch_lightning.utilities.cli module | 1.9 | pytorch_lightning.cli |
| Objects pytorch_lightning.utilities.cli.{OPTIMIZER,LR_SCHEDULER,MODEL,DATAMODULE,CALLBACK,LOGGER}_REGISTRY | 1.9 | Not necessary anymore |
| Argument LightningCLI(auto_registry=...) | 1.9 | Not necessary anymore |
| Argument Trainer(strategy="ddp2") and class pytorch_lightning.strategies.DDP2Strategy | 1.8 | No longer supported |
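For the deprecated Trainer device arguments, the migration amounts to renaming: each old argument maps to one accelerator name, and its value moves to devices. The helper below sketches that translation on a plain dict of keyword arguments. `migrate_trainer_kwargs` is hypothetical and only covers simple cases with a single positive device count; it is not a Lightning utility.

```python
# Old Trainer argument -> value for the `accelerator` argument
_DEPRECATED_TO_ACCELERATOR = {
    "num_processes": "cpu",
    "gpus": "gpu",
    "tpu_cores": "tpu",
    "ipus": "ipu",
}


def migrate_trainer_kwargs(kwargs):
    """Return a copy of `kwargs` using accelerator=/devices= instead of deprecated arguments."""
    migrated = dict(kwargs)
    for old_name, accelerator in _DEPRECATED_TO_ACCELERATOR.items():
        value = migrated.pop(old_name, None)
        if value is not None:
            # If several deprecated arguments are set, the last one processed wins.
            migrated["accelerator"] = accelerator
            migrated["devices"] = value
    return migrated


print(migrate_trainer_kwargs({"gpus": 2, "max_epochs": 3}))
# {'max_epochs': 3, 'accelerator': 'gpu', 'devices': 2}
```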

CHANGELOG

Added
  • Added ServableModule and its associated callback called ServableModuleValidator to ensure the model can be served (#13614)
  • Converted validation loop config warnings to PossibleUserWarning (#13377)
  • Added a flag named log_rank_zero_only to EarlyStopping to disable logging to non-zero rank processes (#13233)
  • Added support for reloading the last checkpoint saved by passing ckpt_path="last" (#12816)
  • Added LightningDataModule.load_from_checkpoint to support loading datamodules directly from checkpoint (#12550)
  • Added a friendly error message when attempting to call Trainer.save_checkpoint() without a model attached (#12772)
  • Added a friendly error message when attempting to use DeepSpeedStrategy on unsupported accelerators (#12699)
  • Enabled torch.inference_mode for evaluation and prediction (#12715)
  • Added support for setting val_check_interval to a value higher than the amount of training batches when check_val_every_n_epoch=None (#11993)
  • Include the pytorch_lightning version as a header in the CLI config files (#12532)
  • Added support for Callback registration through entry points (#12739)
  • Added support for Trainer(deterministic="warn") to warn instead of fail when a non-deterministic operation is encountered (#12588)
  • Added profiling to the loops' dataloader __next__ calls (#12124)
  • Hivemind Strategy
    • Added CollaborativeStrategy (#12842)
    • Renamed CollaborativeStrategy to HivemindStrategy (#13388)
    • Removed unnecessary endpoint logic, renamed collaborative to hivemind (#13392)
  • Include a version suffix for new "last" checkpoints of later runs in the same directory (#12902)
  • Show a better error message when a Metric that does not return a Tensor is logged (#13164)
  • Added missing predict_dataset argument in LightningDataModule.from_datasets to create predict dataloaders (#12942)
  • Added class name prefix to metrics logged by DeviceStatsMonitor (#12228)
  • Automatically wrap custom samplers under a distributed environment by using DistributedSamplerWrapper (#12959)
  • Added profiling of LightningDataModule hooks (#12971)
  • Added Native FSDP Strategy (#12447)
  • Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance (#12938)
  • Added Checkpoint class to inherit from (#13024)
  • Added CPU metric tracking to DeviceStatsMonitor (#11795)
  • Added teardown() method to Accelerator (#11935)
  • Added support for using custom Trainers that don't include callbacks using the CLI (#13138)
  • Added a timeout argument to DDPStrategy and DDPSpawnStrategy. (#13244, #13383)
  • Added XLAEnvironment cluster environment plugin (#11330)
  • Added logging messages to notify when FitLoop stopping conditions are met (#9749)
  • Added support for calling unknown methods with DummyLogger (#13224)
  • Added support for recursively setting the Trainer reference for ensembles of LightningModules (#13638)
  • Added Apple Silicon Support via MPSAccelerator (#13123)
  • Added support for DDP Fork (#13405)
  • Added support for async checkpointing (#13658)
  • Added support for HPU Device stats monitor (#13819)
Changed
  • accelerator="gpu" now automatically selects an available GPU backend (CUDA and MPS currently) (#13642)
  • Enable validation during overfitting (#12527)
  • Added dataclass support to extract_batch_size (#12573)
  • Changed checkpoints save path in the case of one logger and user-provided weights_save_path from weights_save_path/name/version/checkpoints to weights_save_path/checkpoints (#12372)
  • Changed checkpoints save path in the case of multiple loggers and user-provided weights_save_path from weights_save_path/name1_name2/version1_version2/checkpoints to weights_save_path/checkpoints (#12372)
  • Marked swa_lrs argument in StochasticWeightAveraging callback as required (#12556)
  • LightningCLI's shorthand notation changed to use jsonargparse native feature (#12614)
  • LightningCLI changed to use jsonargparse native support for list append (#13129)
  • Changed seed_everything_default argument in the LightningCLI to type Union[bool, int]. If set to True a seed is automatically generated for the parser argument --seed_everything. (#12822, #13110)
  • Make positional arguments required for classes passed into the add_argparse_args function. (#12504)
  • Raise an error if there are insufficient training batches when using a float value of limit_train_batches (#12885)
  • DataLoader instantiated inside a *_dataloader hook will not set the passed arguments as attributes anymore (#12981)
  • When a multi-element tensor is logged, an error is now raised instead of silently taking the mean of all elements (#13164)
  • The WandbLogger will now use the run name in the logs folder if it is provided, and otherwise the project name (#12604)
  • Enabled using any Sampler in distributed environment in Lite (#13646)
  • Raised a warning instead of forcing sync_dist=True on epoch end (#13364)
  • Updated val_check_interval(int) to consider total train batches processed instead of _batches_that_stepped for validation check during training (#12832)
  • Updated Habana Accelerator's auto_device_count, is_available & get_device_name methods based on the latest torch habana package (#13423)
  • Disallowed using BatchSampler when running on multiple IPUs (#13854)
Deprecated
  • Deprecated pytorch_lightning.accelerators.gpu.GPUAccelerator in favor of pytorch_lightning.accelerators.cuda.CUDAAccelerator (#13636)
  • Deprecated pytorch_lightning.loggers.base.LightningLoggerBase in favor of pytorch_lightning.loggers.logger.Logger, and deprecated pytorch_lightning.loggers.base in favor of pytorch_lightning.loggers.logger (#120148)
  • Deprecated pytorch_lightning.callbacks.base.Callback in favor of pytorch_lightning.callbacks.callback.Callback (#13031)
  • Deprecated num_processes, gpus, tpu_cores, and ipus from the Trainer constructor in favor of using the accelerator and devices arguments (#11040)
  • Deprecated setting LightningCLI(seed_everything_default=None) in favor of False (#12804).
  • Deprecated pytorch_lightning.core.lightning.LightningModule in favor of pytorch_lightning.core.module.LightningModule (#12740)
  • Deprecated pytorch_lightning.loops.base.Loop in favor of pytorch_lightning.loops.loop.Loop (#13043)
  • Deprecated Trainer.reset_train_val_dataloaders() in favor of Trainer.reset_{train,val}_dataloader (#12184)
  • Deprecated LightningCLI's registries in favor of importing the respective package (#13221)
  • Deprecated public utilities in pytorch_lightning.utilities.cli.LightningCLI in favor of equivalent copies in pytorch_lightning.cli.LightningCLI (#13767)
  • Deprecated pytorch_lightning.profiler in favor of pytorch_lightning.profilers (#12308)
Removed
  • Removed deprecated IndexBatchSamplerWrapper.batch_indices (#13565)
  • Removed the deprecated LightningModule.add_to_queue and LightningModule.get_from_queue method (#13600)
  • Removed deprecated pytorch_lightning.core.decorators.parameter_validation from decorators (#13514)
  • Removed the deprecated Logger.close method (#13149)
  • Removed the deprecated weights_summary argument from the Trainer constructor (#13070)
  • Removed the deprecated flush_logs_every_n_steps argument from the Trainer constructor (#13074)
  • Removed the deprecated process_position argument from the Trainer constructor (#13071)
  • Removed the deprecated checkpoint_callback argument from the Trainer constructor (#13027)
  • Removed the deprecated on_{train,val,test,predict}_dataloader hooks from the LightningModule and LightningDataModule (#13033)
  • Removed the deprecated TestTubeLogger (#12859)
  • Removed the deprecated pytorch_lightning.core.memory.LayerSummary and pytorch_lightning.core.memory.ModelSummary (#12593)
  • Removed the deprecated summarize method from the LightningModule (#12559)
  • Removed the deprecated model_size property from the LightningModule class (#12641)
  • Removed the deprecated stochastic_weight_avg argument from the Trainer constructor (#12535)
  • Removed the deprecated progress_bar_refresh_rate argument from the Trainer constructor (#12514)
  • Removed the deprecated prepare_data_per_node argument from the Trainer constructor (#12536)
  • Removed the deprecated pytorch_lightning.core.memory.{get_gpu_memory_map,get_memory_profile} (#12659)
  • Removed the deprecated terminate_on_nan argument from the Trainer constructor (#12553)
  • Removed the deprecated XLAStatsMonitor callback (#12688)
  • Removed deprecated pytorch_lightning.callbacks.progress.progress (#12658)
  • Removed the deprecated dim and size arguments from the LightningDataModule constructor (#12780)
  • Removed the deprecated train_transforms argument from the LightningDataModule constructor (#12662)
  • Removed the deprecated log_gpu_memory argument from the Trainer constructor (#12657)
  • Removed the deprecated automatic logging of GPU stats by the logger connector (#12657)
  • Removed deprecated GPUStatsMonitor callback (#12554)
  • Removed support for passing strategy names or strategy instances to the accelerator Trainer argument (#12696)
  • Removed support for passing strategy names or strategy instances to the plugins Trainer argument (#12700)
  • Removed the deprecated val_transforms argument from the LightningDataModule constructor (#12763)
  • Removed the deprecated test_transforms argument from the LightningDataModule constructor (#12773)
  • Removed deprecated Trainer(max_steps=None) (#13591)
  • Removed deprecated dataloader_idx argument from on_train_batch_start/end hooks Callback and LightningModule (#12769, #12977)
  • Removed deprecated get_progress_bar_dict property from LightningModule (#12839)
  • Removed sanity check for multi-optimizer support with habana backends (#13217)
  • Removed the need to explicitly load habana module (#13338)
  • Removed the deprecated Strategy.post_dispatch() hook (#13461)
  • Removed deprecated pytorch_lightning.callbacks.lr_monitor.LearningRateMonitor.lr_sch_names (#13353)
  • Removed deprecated Trainer.slurm_job_id in favor of SLURMEnvironment.job_id (#13459)
  • Removed support for the DDP2Strategy (#12705)
  • Removed deprecated LightningDistributed (#13549)
  • Removed deprecated ClusterEnvironment properties master_address and master_port in favor of main_address and main_port (#13458)
  • Removed deprecated ClusterEnvironment methods KubeflowEnvironment.is_using_kubeflow(), LSFEnvironment.is_using_lsf() and TorchElasticEnvironment.is_using_torchelastic() in favor of the detect() method (#13458)
  • Removed deprecated Callback.on_keyboard_interrupt (#13438)
  • Removed deprecated LightningModule.on_post_move_to_device (#13548)
  • Removed TPUSpawnStrategy.{tpu_local_core_rank,tpu_global_core_rank} attributes in favor of TPUSpawnStrategy.{local_rank,global_rank} (#11163)
  • Removed SingleTPUStrategy.{tpu_local_core_rank,tpu_global_core_rank} attributes in favor of SingleTPUStrategy.{local_rank,global_rank} (#11163)
Fixed
  • Improved support for custom DataLoaders when instantiated in *_dataloader hook (#12981)
  • Allowed custom BatchSamplers when instantiated in *_dataloader hook (#13640)
  • Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad (#13014)
  • The model wrapper returned by LightningLite.setup() now properly supports pass-through when looking up attributes (#12597)
  • Fixed issue where the CLI fails with certain torch objects (#13153)
  • Fixed LightningCLI signature parameter resolving for some lightning classes (#13283)
  • Fixed Model Summary when using DeepSpeed Stage 3 (#13427)
  • Fixed pytorch_lightning.utilities.distributed.gather_all_tensors to handle tensors of different dimensions (#12630)
  • Fixed the input validation for the accelerator Trainer argument when passed as a string (#13417)
  • Fixed Trainer.predict(return_predictions=False) to track prediction's batch_indices (#13629)
  • Fixed an issue that prevented setting a custom CheckpointIO plugin with strategies (#13785)
  • Fixed main progress bar counter when val_check_interval=int and check_val_every_n_epoch=None (#12832)
  • Improved support for custom ReduceLROnPlateau scheduler if reduce_on_plateau is set by the user in scheduler config (#13838)
  • Used global_step while restoring logging step for old checkpoints (#13645)
  • When training with precision=16 on IPU, the cast has been moved off the IPU onto the host, making the copies from host to IPU cheaper (#13880)
  • Fixed error handling in learning rate finder when not enough data points are available to give a good suggestion (#13845)
  • Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible (#13845)
  • Fixed an issue causing deterministic algorithms and other globals to get reset in spawned processes (#13921)
  • Fixed default amp_level for DeepSpeedPrecisionPlugin to O2 (#13897)
  • Fixed Python 3.10 compatibility for truncated back-propagation through time (TBPTT) (#13973)
  • Fixed TQDMProgressBar reset and update to show correct time estimation (2/2) (#13962)

Full commit list: 1.6.0...1.7.0

Contributors

Veteran

@akashkw @akihironitta @aniketmaurya @awaelchli @Benjamin-Etheredge @Borda @carmocca @catalys1 @daniellepintz @edenlightning @edward-io @EricWiener @fschlatt @ftorres16 @jerome-habana @justusschock @karthikrangasai @kaushikb11 @krishnakalyan3 @krshrimali @mauvilsa @nikvaessen @otaj @pre-commit-ci @puhuk @raoakarsha @rasbt @rohitgr7 @SeanNaren @s-rog @talregev @tchaton @tshu-w @twsl @weiji14 @williamFalcon @WrRan

New

@alvitawa @aminst @ankitaS11 @ar90n @Atharva-Phatak @bibhabasumohapatra @BongYang @code-review-doctor @CompRhys @Cyprien-Ricque @dependabot @digital-idiot @DN6 @donlapark @ekagra-ranjan @ethanfurman @gautierdag @georgestein @HallerPatrick @HenryLau0220 @hhsecond @himkt @hmellor @igorgad @inwaves @ishtos @JeroenDelcour @JiahaoYao @jiny419 @jinyoung-lim @JustinGoheen @jxmorris12 @Keiku @kingjuno @lsy643 @luca-medeiros @lukasugar @maciek-pioro @mads-oestergaard @manskx @martinosorb @MohammedAlkhrashi @MrShevan @myxik @naisofly @NathanielDamours @nayoungjun @niberger @nitinramvelraj @nninept @pbsds @Pragyanstha @PrajwalBorkar @Prometheos2 @rampartrange @rhjohnstone @rschireman @samz5320 @Schinkikami @semaphore-egg @shantam-8 @shenoynikhil @sisilmehta2000 @s-kumano @stanbiryukov @talregev @tanmoyio @tkonopka @vumichien @wangherr @yhl48 @YongWookHa

If we forgot somebody or you have a suggestion, find support here

Did you know?

Chuck Norris can unit-test entire applications with a single assert.