diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/flash/core/integrations/pytorch_forecasting/adapter.py index 7bf6dceae3..f77d0f8e56 100644 --- a/flash/core/integrations/pytorch_forecasting/adapter.py +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -84,8 +84,7 @@ def from_task( metrics = [metrics] backbone_kwargs["logging_metrics"] = metrics - if not backbone_kwargs: - backbone_kwargs = {} + backbone_kwargs = backbone_kwargs or {} adapter = cls(task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs))