Skip to content

Commit ea7cb1d

Browse files
Apply the changes (#3069)
1 parent 8ec312c commit ea7cb1d

File tree

6 files changed

+525
-522
lines changed

6 files changed

+525
-522
lines changed

Diff for: ignite/engine/__init__.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _prepare_batch(
4444
def supervised_training_step(
4545
model: torch.nn.Module,
4646
optimizer: torch.optim.Optimizer,
47-
loss_fn: Union[Callable, torch.nn.Module],
47+
loss_fn: Union[Callable[[Any, Any], torch.Tensor], torch.nn.Module],
4848
device: Optional[Union[str, torch.device]] = None,
4949
non_blocking: bool = False,
5050
prepare_batch: Callable = _prepare_batch,
@@ -57,7 +57,7 @@ def supervised_training_step(
5757
Args:
5858
model: the model to train.
5959
optimizer: the optimizer to use.
60-
loss_fn: the loss function to use.
60+
loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor.
6161
device: device type specification (default: None).
6262
Applies to batches after starting the engine. Model *will not* be moved.
6363
Device can be CPU, GPU.
@@ -120,7 +120,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
120120
def supervised_training_step_amp(
121121
model: torch.nn.Module,
122122
optimizer: torch.optim.Optimizer,
123-
loss_fn: Union[Callable, torch.nn.Module],
123+
loss_fn: Union[Callable[[Any, Any], torch.Tensor], torch.nn.Module],
124124
device: Optional[Union[str, torch.device]] = None,
125125
non_blocking: bool = False,
126126
prepare_batch: Callable = _prepare_batch,
@@ -134,7 +134,7 @@ def supervised_training_step_amp(
134134
Args:
135135
model: the model to train.
136136
optimizer: the optimizer to use.
137-
loss_fn: the loss function to use.
137+
loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor.
138138
device: device type specification (default: None).
139139
Applies to batches after starting the engine. Model *will not* be moved.
140140
Device can be CPU, GPU.
@@ -212,7 +212,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
212212
def supervised_training_step_apex(
213213
model: torch.nn.Module,
214214
optimizer: torch.optim.Optimizer,
215-
loss_fn: Union[Callable, torch.nn.Module],
215+
loss_fn: Union[Callable[[Any, Any], torch.Tensor], torch.nn.Module],
216216
device: Optional[Union[str, torch.device]] = None,
217217
non_blocking: bool = False,
218218
prepare_batch: Callable = _prepare_batch,
@@ -225,7 +225,7 @@ def supervised_training_step_apex(
225225
Args:
226226
model: the model to train.
227227
optimizer: the optimizer to use.
228-
loss_fn: the loss function to use.
228+
loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor.
229229
device: device type specification (default: None).
230230
Applies to batches after starting the engine. Model *will not* be moved.
231231
Device can be CPU, GPU.
@@ -295,7 +295,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
295295
def supervised_training_step_tpu(
296296
model: torch.nn.Module,
297297
optimizer: torch.optim.Optimizer,
298-
loss_fn: Union[Callable, torch.nn.Module],
298+
loss_fn: Union[Callable[[Any, Any], torch.Tensor], torch.nn.Module],
299299
device: Optional[Union[str, torch.device]] = None,
300300
non_blocking: bool = False,
301301
prepare_batch: Callable = _prepare_batch,
@@ -308,7 +308,7 @@ def supervised_training_step_tpu(
308308
Args:
309309
model: the model to train.
310310
optimizer: the optimizer to use.
311-
loss_fn: the loss function to use.
311+
loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor.
312312
device: device type specification (default: None).
313313
Applies to batches after starting the engine. Model *will not* be moved.
314314
Device can be CPU, TPU.
@@ -404,7 +404,7 @@ def _check_arg(
404404
def create_supervised_trainer(
405405
model: torch.nn.Module,
406406
optimizer: torch.optim.Optimizer,
407-
loss_fn: Union[Callable, torch.nn.Module],
407+
loss_fn: Union[Callable[[Any, Any], torch.Tensor], torch.nn.Module],
408408
device: Optional[Union[str, torch.device]] = None,
409409
non_blocking: bool = False,
410410
prepare_batch: Callable = _prepare_batch,
@@ -420,7 +420,7 @@ def create_supervised_trainer(
420420
Args:
421421
model: the model to train.
422422
optimizer: the optimizer to use.
423-
loss_fn: the loss function to use.
423+
loss_fn: the loss function that receives `y_pred` and `y`, and returns the loss as a tensor.
424424
device: device type specification (default: None).
425425
Applies to batches after starting the engine. Model *will not* be moved.
426426
Device can be CPU, GPU or TPU.

Diff for: tests/ignite/conftest.py

+1
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ def gloo_hvd_executor():
397397
],
398398
),
399399
],
400+
scope="class",
400401
)
401402
def distributed(request, local_rank, world_size):
402403
if request.param in ("nccl", "gloo_cpu", "gloo"):

0 commit comments

Comments
 (0)