From 67cf36403ffbbc58ac38600411397f3086a65830 Mon Sep 17 00:00:00 2001 From: Nico Trummer Date: Tue, 30 Jan 2024 09:38:58 +0100 Subject: [PATCH] Fix minibatch size problem in SOLO --- modules/solo.nf | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/modules/solo.nf b/modules/solo.nf index ad39fe2..1b4fc20 100644 --- a/modules/solo.nf +++ b/modules/solo.nf @@ -32,19 +32,21 @@ process SOLO { results = [] - batch_sizes = adata.obs["batch"].value_counts() - batches = "${batches.join(" ")}".split(" ") for batch in batches: - # If only one cell is assigned to a minibatch, solo will crash - # https://discuss.pytorch.org/t/error-expected-more-than-1-value-per-channel-when-training/26274 - batch_size = batch_sizes[batch] - default_minibatch_size = 128 - minibatch_size_correction = -1 if batch_size % default_minibatch_size == 1 else 0 - solo = scvi.external.SOLO.from_scvi_model(scvi_model, restrict_to_batch=batch) - solo.train(batch_size=default_minibatch_size + minibatch_size_correction) - + + minibatch_size = 128 + worked = False + while not worked and minibatch_size > 100: + try: + solo.train(batch_size=minibatch_size) + worked = True + except ValueError: + print("Minibatch size did not work, trying again with smaller minibatch size") + minibatch_size -= 1 + pass + batch_res = solo.predict() batch_res["doublet_label"] = solo.predict(False)