@@ -44,7 +44,7 @@ def _prepare_batch(
44
44
def supervised_training_step (
45
45
model : torch .nn .Module ,
46
46
optimizer : torch .optim .Optimizer ,
47
- loss_fn : Union [Callable , torch .nn .Module ],
47
+ loss_fn : Union [Callable [[ Any , Any ], torch . Tensor ] , torch .nn .Module ],
48
48
device : Optional [Union [str , torch .device ]] = None ,
49
49
non_blocking : bool = False ,
50
50
prepare_batch : Callable = _prepare_batch ,
@@ -57,7 +57,7 @@ def supervised_training_step(
57
57
Args:
58
58
model: the model to train.
59
59
optimizer: the optimizer to use.
60
- loss_fn: the loss function to use .
60
+ loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor .
61
61
device: device type specification (default: None).
62
62
Applies to batches after starting the engine. Model *will not* be moved.
63
63
Device can be CPU, GPU.
@@ -120,7 +120,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
120
120
def supervised_training_step_amp (
121
121
model : torch .nn .Module ,
122
122
optimizer : torch .optim .Optimizer ,
123
- loss_fn : Union [Callable , torch .nn .Module ],
123
+ loss_fn : Union [Callable [[ Any , Any ], torch . Tensor ] , torch .nn .Module ],
124
124
device : Optional [Union [str , torch .device ]] = None ,
125
125
non_blocking : bool = False ,
126
126
prepare_batch : Callable = _prepare_batch ,
@@ -134,7 +134,7 @@ def supervised_training_step_amp(
134
134
Args:
135
135
model: the model to train.
136
136
optimizer: the optimizer to use.
137
- loss_fn: the loss function to use .
137
+ loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor .
138
138
device: device type specification (default: None).
139
139
Applies to batches after starting the engine. Model *will not* be moved.
140
140
Device can be CPU, GPU.
@@ -212,7 +212,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
212
212
def supervised_training_step_apex (
213
213
model : torch .nn .Module ,
214
214
optimizer : torch .optim .Optimizer ,
215
- loss_fn : Union [Callable , torch .nn .Module ],
215
+ loss_fn : Union [Callable [[ Any , Any ], torch . Tensor ] , torch .nn .Module ],
216
216
device : Optional [Union [str , torch .device ]] = None ,
217
217
non_blocking : bool = False ,
218
218
prepare_batch : Callable = _prepare_batch ,
@@ -225,7 +225,7 @@ def supervised_training_step_apex(
225
225
Args:
226
226
model: the model to train.
227
227
optimizer: the optimizer to use.
228
- loss_fn: the loss function to use .
228
+ loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor .
229
229
device: device type specification (default: None).
230
230
Applies to batches after starting the engine. Model *will not* be moved.
231
231
Device can be CPU, GPU.
@@ -295,7 +295,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
295
295
def supervised_training_step_tpu (
296
296
model : torch .nn .Module ,
297
297
optimizer : torch .optim .Optimizer ,
298
- loss_fn : Union [Callable , torch .nn .Module ],
298
+ loss_fn : Union [Callable [[ Any , Any ], torch . Tensor ] , torch .nn .Module ],
299
299
device : Optional [Union [str , torch .device ]] = None ,
300
300
non_blocking : bool = False ,
301
301
prepare_batch : Callable = _prepare_batch ,
@@ -308,7 +308,7 @@ def supervised_training_step_tpu(
308
308
Args:
309
309
model: the model to train.
310
310
optimizer: the optimizer to use.
311
- loss_fn: the loss function to use .
311
+ loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor .
312
312
device: device type specification (default: None).
313
313
Applies to batches after starting the engine. Model *will not* be moved.
314
314
Device can be CPU, TPU.
@@ -404,7 +404,7 @@ def _check_arg(
404
404
def create_supervised_trainer (
405
405
model : torch .nn .Module ,
406
406
optimizer : torch .optim .Optimizer ,
407
- loss_fn : Union [Callable , torch .nn .Module ],
407
+ loss_fn : Union [Callable [[ Any , Any ], torch . Tensor ] , torch .nn .Module ],
408
408
device : Optional [Union [str , torch .device ]] = None ,
409
409
non_blocking : bool = False ,
410
410
prepare_batch : Callable = _prepare_batch ,
@@ -420,7 +420,7 @@ def create_supervised_trainer(
420
420
Args:
421
421
model: the model to train.
422
422
optimizer: the optimizer to use.
423
- loss_fn: the loss function to use .
423
+ loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor .
424
424
device: device type specification (default: None).
425
425
Applies to batches after starting the engine. Model *will not* be moved.
426
426
Device can be CPU, GPU or TPU.
0 commit comments