Commit 27c7630

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1f228ef commit 27c7630
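This is an automated formatting pass, so none of the hunks below change behaviour: long conditional expressions are re-wrapped in parentheses (in the style of Black) and stray whitespace/blank lines are cleaned up. The repository's actual .pre-commit-config.yaml is not part of this commit; a minimal sketch of a config that would produce fixes like these (repository URLs and hook ids are real pre-commit hooks, the pinned revisions are illustrative only):

    # .pre-commit-config.yaml (sketch, not the repo's file)
    repos:
      - repo: https://github.com/psf/black
        rev: 23.1.0
        hooks:
          - id: black              # re-wraps long expressions, e.g. the strategy/get_lr conditionals below
      - repo: https://github.com/pre-commit/pre-commit-hooks
        rev: v4.4.0
        hooks:
          - id: trailing-whitespace # e.g. the all4one.yaml fix below
          - id: end-of-file-fixer   # e.g. the trailing blank line removed from utils.rst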

File tree

6 files changed (+14 additions, -17 deletions)

6 files changed

+14
-17
lines changed

docs/source/solo/utils.rst (-1)

@@ -257,4 +257,3 @@ forward
 ~~~~~~~
 .. automethod:: solo.utils.positional_encoding.Summer.forward
    :noindex:
-

main_linear.py (+3 -3)

@@ -203,9 +203,9 @@ def main(cfg: DictConfig):
             "logger": wandb_logger if cfg.wandb.enabled else None,
             "callbacks": callbacks,
             "enable_checkpointing": False,
-            "strategy": DDPStrategy(find_unused_parameters=False)
-            if cfg.strategy == "ddp"
-            else cfg.strategy,
+            "strategy": (
+                DDPStrategy(find_unused_parameters=False) if cfg.strategy == "ddp" else cfg.strategy
+            ),
         }
     )
     trainer = Trainer(**trainer_kwargs)

main_pretrain.py (+3 -3)

@@ -221,9 +221,9 @@ def main(cfg: DictConfig):
             "logger": wandb_logger if cfg.wandb.enabled else None,
             "callbacks": callbacks,
             "enable_checkpointing": False,
-            "strategy": DDPStrategy(find_unused_parameters=False)
-            if cfg.strategy == "ddp"
-            else cfg.strategy,
+            "strategy": (
+                DDPStrategy(find_unused_parameters=False) if cfg.strategy == "ddp" else cfg.strategy
+            ),
         }
     )
     trainer = Trainer(**trainer_kwargs)
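Both main_linear.py and main_pretrain.py receive the identical reformat: the conditional expression that picks the Lightning strategy is wrapped in parentheses so the formatter can keep it on a single line. A minimal, self-contained sketch of that pattern (the resolve_strategy helper is hypothetical, and the import path is assumed for recent PyTorch Lightning versions; behaviour is the same before and after the rewrap):

    # Sketch only: illustrates the parenthesized conditional, not the repo's full main().
    from pytorch_lightning.strategies import DDPStrategy  # import path assumed

    def resolve_strategy(strategy: str):
        # Same logic as the diff: build a DDPStrategy for "ddp", otherwise pass the string through.
        return (
            DDPStrategy(find_unused_parameters=False) if strategy == "ddp" else strategy
        )

    print(type(resolve_strategy("ddp")).__name__)  # DDPStrategy
    print(resolve_strategy("auto"))                # auto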

scripts/pretrain/cifar/all4one.yaml (+1 -1)

@@ -28,7 +28,7 @@ data:
   dataset: cifar100 # change here for cifar10
   train_path: "./datasets/"
   val_path: "./datasets/"
-  format: "image_folder"
+  format: "image_folder"
   num_workers: 4
 optimizer:
   name: "lars"
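The removed and added lines here are visually identical, so this is a whitespace-only fix (most likely trailing whitespace stripped by a pre-commit hook). For context, the hunk sits inside a pretraining config; a hedged sketch of reading the keys shown above with OmegaConf (solo-learn normally drives these configs through Hydra, so this is purely illustrative):

    # Sketch: inspect the config keys visible in this hunk.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("scripts/pretrain/cifar/all4one.yaml")
    print(cfg.data.dataset)    # "cifar100" (switch to cifar10 per the inline comment)
    print(cfg.data.format)     # "image_folder"
    print(cfg.optimizer.name)  # "lars"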

solo/methods/base.py (+5 -3)

@@ -389,9 +389,11 @@ def configure_optimizers(self) -> Tuple[List, List]:
         if idxs_no_scheduler:
             partial_fn = partial(
                 static_lr,
-                get_lr=scheduler["scheduler"].get_lr
-                if isinstance(scheduler, dict)
-                else scheduler.get_lr,
+                get_lr=(
+                    scheduler["scheduler"].get_lr
+                    if isinstance(scheduler, dict)
+                    else scheduler.get_lr
+                ),
                 param_group_indexes=idxs_no_scheduler,
                 lrs_to_replace=[self.lr] * len(idxs_no_scheduler),
             )
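This is the same kind of rewrap: the conditional expression passed as the get_lr keyword of functools.partial gains explicit parentheses, with no change to what gets bound. A runnable sketch of the call shape, using stand-ins for static_lr and the scheduler (the real helper lives in solo-learn's utils; only the keyword signature implied by this hunk is reproduced here):

    from functools import partial

    # Stand-in for solo-learn's static_lr: pin the learning rate of selected param groups.
    def static_lr(get_lr, param_group_indexes, lrs_to_replace):
        lrs = get_lr()
        for idx, lr in zip(param_group_indexes, lrs_to_replace):
            lrs[idx] = lr
        return lrs

    # Stand-in for scheduler["scheduler"].get_lr / scheduler.get_lr in the diff.
    def scheduler_get_lr():
        return [0.1, 0.1, 0.1]

    partial_fn = partial(
        static_lr,
        get_lr=scheduler_get_lr,
        param_group_indexes=[2],
        lrs_to_replace=[0.01],
    )
    print(partial_fn())  # [0.1, 0.1, 0.01]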

solo/utils/positional_encodings.py (+2 -6)

@@ -100,9 +100,7 @@ def forward(self, tensor):
         sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
         emb_x = get_emb(sin_inp_x).unsqueeze(1)
         emb_y = get_emb(sin_inp_y)
-        emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(
-            tensor.type()
-        )
+        emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(tensor.type())
         emb[:, :, : self.channels] = emb_x
         emb[:, :, self.channels : 2 * self.channels] = emb_y

@@ -165,9 +163,7 @@ def forward(self, tensor):
         emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1)
         emb_y = get_emb(sin_inp_y).unsqueeze(1)
         emb_z = get_emb(sin_inp_z)
-        emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(
-            tensor.type()
-        )
+        emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(tensor.type())
         emb[:, :, :, : self.channels] = emb_x
         emb[:, :, :, self.channels : 2 * self.channels] = emb_y
         emb[:, :, :, 2 * self.channels :] = emb_z
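In both forward passes the change is again only layout: the multi-line torch.zeros(...).type(tensor.type()) call is collapsed onto one line. A small self-contained sketch of what that line does, with hypothetical shapes and names (this is not the positional-encoding class itself): allocate a zero buffer on the same device and dtype as the input, then fill channel slices with the per-axis embeddings.

    import torch

    # Hypothetical stand-ins for the class attributes/locals in the diff.
    channels = 4
    x, y = 3, 5
    tensor = torch.rand(1, x, y, 2 * channels)   # reference input: supplies device/dtype
    emb_x = torch.rand(x, 1, channels)           # per-axis embeddings, shaped to broadcast
    emb_y = torch.rand(y, channels)

    # The reformatted line: zero buffer matching the input's device and dtype.
    emb = torch.zeros((x, y, channels * 2), device=tensor.device).type(tensor.type())
    emb[:, :, :channels] = emb_x                 # x-embedding broadcast over the y axis
    emb[:, :, channels : 2 * channels] = emb_y   # y-embedding broadcast over the x axis
    print(emb.shape)                             # torch.Size([3, 5, 8])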
