This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 212
/
model.py
270 lines (225 loc) · 10.7 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from copy import deepcopy
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
import torch
import torchmetrics
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from torch import nn
from flash.core.registry import FlashRegistry
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess
def predict_context(func: Callable) -> Callable:
    """
    Decorator that runs a method with the module in eval mode and gradients disabled.

    The module's ``training`` flag and the global autograd-enabled flag are captured before
    the call and restored afterwards, including when the wrapped method raises.

    Args:
        func: The method to wrap. Its first positional argument must be an object exposing
            ``training``, ``eval()`` and ``train()``.

    Returns:
        The wrapped method, with the metadata of ``func`` preserved.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs) -> Any:
        # Remember the state we are about to override so it can be put back afterwards.
        grad_enabled = torch.is_grad_enabled()
        is_training = self.training

        self.eval()
        torch.set_grad_enabled(False)
        try:
            return func(self, *args, **kwargs)
        finally:
            # Restore even on failure: otherwise an exception inside ``func`` would leave
            # the module in eval mode and autograd switched off.
            if is_training:
                self.train()
            torch.set_grad_enabled(grad_enabled)

    return wrapper
class Task(LightningModule):
    """A general Task.

    Bundles a model with its loss function(s), metrics and optimizer class, and keeps track of an
    optional :class:`DataPipeline` (pre-/post-processing) that is attached to the model around the
    dataloader hooks and stored in checkpoints.

    Args:
        model: Model to use for the task.
        loss_fn: Loss function for training
        optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
        metrics: Metrics to compute for training and evaluation.
        learning_rate: Learning rate to use for training, defaults to `5e-5`
    """

    def __init__(
        self,
        model: Optional[nn.Module] = None,
        loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
        learning_rate: float = 5e-5,
    ):
        super().__init__()
        if model is not None:
            self.model = model
        # Losses and metrics are normalised to name -> callable mappings.
        self.loss_fn = {} if loss_fn is None else get_callable_dict(loss_fn)
        self.optimizer_cls = optimizer
        self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics))
        self.learning_rate = learning_rate
        # TODO: should we save more? Bug on some regarding yaml if we save metrics
        self.save_hyperparameters("learning_rate", "optimizer")

        self._data_pipeline = None
        self._preprocess = None
        self._postprocess = None

    def step(self, batch: Any, batch_idx: int) -> Any:
        """
        The training/validation/test step. Override for custom behavior.

        Args:
            batch: A pair ``(x, y)`` of model input and target.
            batch_idx: Index of the batch; not used by this implementation.

        Returns:
            A dict with the keys ``"y_hat"`` (model output), ``"loss"`` (the single loss, or the
            sum of all losses when several loss functions are configured), ``"logs"`` (metrics and
            losses by name, plus ``"total_loss"`` when several losses exist) and ``"y"``.

        Raises:
            IndexError: If no loss function is configured.
        """
        x, y = batch
        y_hat = self(x)
        output = {"y_hat": y_hat}
        losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
        logs = {}
        for name, metric in self.metrics.items():
            if isinstance(metric, torchmetrics.metric.Metric):
                metric(y_hat, y)
                logs[name] = metric  # log the metric itself if it is of type Metric
            else:
                logs[name] = metric(y_hat, y)

        logs.update(losses)
        if len(losses) > 1:
            logs["total_loss"] = sum(losses.values())
            # Keep the dict contract: the ``*_step`` hooks index the result with
            # "loss" / "logs", so returning a tuple here would break them.
            output["loss"] = logs["total_loss"]
        else:
            output["loss"] = list(losses.values())[0]
        output["logs"] = logs
        output["y"] = y
        return output

    def forward(self, x: Any) -> Any:
        """Run ``x`` through the wrapped model and return its output."""
        return self.model(x)

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """Run :meth:`step`, log its logs with a ``train_`` prefix and return the loss."""
        output = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in output["logs"].items()}, on_step=True, on_epoch=True, prog_bar=True)
        return output["loss"]

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        """Run :meth:`step` and log its logs per epoch with a ``val_`` prefix."""
        output = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Any, batch_idx: int) -> None:
        """Run :meth:`step` and log its logs per epoch with a ``test_`` prefix."""
        output = self.step(batch, batch_idx)
        self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)

    @predict_context
    def predict(
        self,
        x: Any,
        data_pipeline: Optional[DataPipeline] = None,
    ) -> Any:
        """
        Predict function for raw data or processed data

        Args:
            x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data.
            data_pipeline: Use this to override the current data pipeline

        Returns:
            The post-processed model predictions
        """
        running_stage = RunningStage.PREDICTING
        data_pipeline = data_pipeline or self.data_pipeline
        # Materialise the auto dataset into a list of samples before pre-processing.
        samples = [sample for sample in data_pipeline._generate_auto_dataset(x, running_stage)]
        batch = data_pipeline.worker_preprocessor(running_stage)(samples)
        batch = self.transfer_batch_to_device(batch, self.device)
        batch = data_pipeline.device_preprocessor(running_stage)(batch)
        predictions = self.predict_step(batch, 0)  # batch_idx is always 0 when running with `model.predict`
        predictions = data_pipeline.postprocessor(running_stage)(predictions)
        return predictions

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """
        Forward a prediction batch through the model.

        A tuple batch is reduced to its first element; a list batch is stacked into a single
        tensor; anything else is passed to the model unchanged.
        """
        if isinstance(batch, tuple):
            batch = batch[0]
        elif isinstance(batch, list):
            # Todo: Understand why stack is needed
            batch = torch.stack(batch)
        return self(batch)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Instantiate the optimizer class over the parameters that require gradients."""
        return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)

    def configure_finetune_callback(self) -> List[Callback]:
        """Return the finetuning callbacks for this task; none by default."""
        return []

    @property
    def preprocess(self) -> Optional[Preprocess]:
        """The preprocess of the explicitly set data pipeline if it has one, else the stored one."""
        return getattr(self._data_pipeline, '_preprocess_pipeline', None) or self._preprocess

    @preprocess.setter
    def preprocess(self, preprocess: Preprocess) -> None:
        self._preprocess = preprocess
        # Rebuild the pipeline so that it pairs the new preprocess with the current postprocess.
        self.data_pipeline = DataPipeline(preprocess, self.postprocess)

    @property
    def postprocess(self) -> Postprocess:
        """
        The postprocess to use, resolved in this order: the stored one, an instance of the
        optional ``postprocess_cls`` attribute, the one of the explicitly set data pipeline,
        and finally a default :class:`Postprocess`.
        """
        postprocess_cls = getattr(self, "postprocess_cls", None)
        return (
            self._postprocess or (postprocess_cls() if postprocess_cls else None)
            or getattr(self._data_pipeline, '_postprocess_pipeline', None) or Postprocess()
        )

    @postprocess.setter
    def postprocess(self, postprocess: Postprocess) -> None:
        self.data_pipeline = DataPipeline(self.preprocess, postprocess)
        # Assigned after the pipeline so the given object is stored regardless of what the
        # ``data_pipeline`` setter decided to keep.
        self._postprocess = postprocess

    @property
    def data_pipeline(self) -> Optional[DataPipeline]:
        """
        The data pipeline of this task.

        Returns the explicitly set pipeline if there is one, otherwise a new pipeline built from
        the ``preprocess`` and ``postprocess`` properties.
        """
        if self._data_pipeline is not None:
            return self._data_pipeline
        elif self.preprocess is not None or self.postprocess is not None:
            # NOTE(review): ``self.postprocess`` always falls back to ``Postprocess()`` and so is
            # never None; this branch is therefore always taken when no pipeline is set and the
            # datamodule / trainer fallbacks below are currently unreachable.
            return DataPipeline(self.preprocess, self.postprocess)
        elif self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
            return self.datamodule.data_pipeline
        elif self.trainer is not None and hasattr(
            self.trainer, 'datamodule'
        ) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None:
            return self.trainer.datamodule.data_pipeline
        return self._data_pipeline

    @data_pipeline.setter
    def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
        self._data_pipeline = data_pipeline
        if data_pipeline is not None and getattr(data_pipeline, '_preprocess_pipeline', None) is not None:
            self._preprocess = data_pipeline._preprocess_pipeline

        if data_pipeline is not None and getattr(data_pipeline, '_postprocess_pipeline', None) is not None:
            # Exact type check on purpose: a plain ``Postprocess`` is not stored, while
            # instances of its subclasses are.
            if type(data_pipeline._postprocess_pipeline) is not Postprocess:
                self._postprocess = data_pipeline._postprocess_pipeline

    def on_train_dataloader(self) -> None:
        """Detach, then re-attach the data pipeline to this model for the training stage."""
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self, RunningStage.TRAINING)
            self.data_pipeline._attach_to_model(self, RunningStage.TRAINING)
        super().on_train_dataloader()

    def on_val_dataloader(self) -> None:
        """Detach, then re-attach the data pipeline to this model for the validation stage."""
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self, RunningStage.VALIDATING)
            self.data_pipeline._attach_to_model(self, RunningStage.VALIDATING)
        super().on_val_dataloader()

    def on_test_dataloader(self, *_) -> None:
        """Detach, then re-attach the data pipeline to this model for the testing stage."""
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self, RunningStage.TESTING)
            self.data_pipeline._attach_to_model(self, RunningStage.TESTING)
        super().on_test_dataloader()

    def on_predict_dataloader(self) -> None:
        """Detach, then re-attach the data pipeline to this model for the predicting stage."""
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self, RunningStage.PREDICTING)
            self.data_pipeline._attach_to_model(self, RunningStage.PREDICTING)
        super().on_predict_dataloader()

    def on_predict_end(self) -> None:
        """Detach the data pipeline from this model once prediction has finished."""
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self)
        super().on_predict_end()

    def on_fit_end(self) -> None:
        """Detach the data pipeline from this model once fitting has finished."""
        if self.data_pipeline is not None:
            self.data_pipeline._detach_from_model(self)
        super().on_fit_end()

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Store the data pipeline in ``checkpoint`` unless a ``'data_pipeline'`` entry already exists."""
        # This may be an issue since here we create the same problems with pickle as in
        # https://pytorch.org/docs/stable/notes/serialization.html
        if self.data_pipeline is not None and 'data_pipeline' not in checkpoint:
            checkpoint['data_pipeline'] = self.data_pipeline
        super().on_save_checkpoint(checkpoint)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Restore the data pipeline from ``checkpoint`` when it contains one."""
        super().on_load_checkpoint(checkpoint)
        if 'data_pipeline' in checkpoint:
            self.data_pipeline = checkpoint['data_pipeline']

    @classmethod
    def available_backbones(cls) -> List[str]:
        """Return the keys of the class-level ``backbones`` registry, or ``[]`` if there is none."""
        registry: Optional[FlashRegistry] = getattr(cls, "backbones", None)
        if registry is None:
            return []
        return registry.available_keys()