diff --git a/main_moco.py b/main_moco.py index d7ea97359..c2f9f0022 100755 --- a/main_moco.py +++ b/main_moco.py @@ -158,7 +158,7 @@ def print_pass(*args): print("=> creating model '{}'".format(args.arch)) model = moco.builder.MoCo( models.__dict__[args.arch], - args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp) + args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp, args.batch_size) print(model) if args.distributed: diff --git a/moco/builder.py b/moco/builder.py index 7d80fe996..96855ad24 100644 --- a/moco/builder.py +++ b/moco/builder.py @@ -8,7 +8,7 @@ class MoCo(nn.Module): Build a MoCo model with: a query encoder, a key encoder, and a queue https://arxiv.org/abs/1911.05722 """ - def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False): + def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False, fixed_batch_size = 256): """ dim: feature dimension (default: 128) K: queue size; number of negative keys (default: 65536) @@ -20,6 +20,7 @@ def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False): self.K = K self.m = m self.T = T + self.fixed_batch_size = fixed_batch_size # create the encoders # num_classes is the output fc dimension @@ -57,11 +58,11 @@ def _dequeue_and_enqueue(self, keys): batch_size = keys.shape[0] ptr = int(self.queue_ptr) - assert self.K % batch_size == 0 # for simplicity + assert self.K % self.fixed_batch_size == 0 # for simplicity # replace the keys at ptr (dequeue and enqueue) self.queue[:, ptr:ptr + batch_size] = keys.T - ptr = (ptr + batch_size) % self.K # move pointer + ptr = (ptr + self.fixed_batch_size) % self.K # move pointer self.queue_ptr[0] = ptr