
Questions regarding the batch size on memory issue #804

Open · tracy666 opened this issue Feb 26, 2025 · 2 comments

tracy666 commented Feb 26, 2025

Dear contributors and dear @MUCDK,

Following your suggestions last time, I've tried setting batch_size to lower the memory consumption. I am experimenting with the MOSTA dataset, which is built into moscot.

However, the program still reports a memory error even when testing with a very small batch_size value (15). And I notice that, no matter what value I set, the error message is identical:

"E0226 11:52:40.852305 2249403 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2335208976 bytes."

The allocation size does not decrease with smaller batch_size values. Also, since I can theoretically run the pancreas dataset with more than 20,000 cells, I would assume a batch size above 10,000 should be fine for MOSTA.

Therefore I am wondering whether I am using this parameter incorrectly. Any suggestion is highly appreciated!

The relevant code for reproduction:

import moscot as mt
import scanpy as sc

# ---------------------------- FINDING MAX BATCH SIZE ---------------------------- #
def find_max_batch_size(tp, adata, device):
    """Find the largest batch size that avoids out-of-memory failures."""
    batch_size_divisor = 2  # Start with adata.shape[0] // 2
    while batch_size_divisor <= adata.shape[0]:  
        batch_size = adata.shape[0] // batch_size_divisor
        print(f"Trying batch_size = {batch_size} (1/{batch_size_divisor} of data)")
        
        try:
            tp.solve(
                epsilon=1e-3,
                scale_cost="mean",
                max_iterations=1e7,
                device=device,
                batch_size=batch_size,
            )
            print(f"✅ Success with batch_size = {batch_size}")
            return batch_size  # Return the successful batch size
        
        except Exception as e:
            # JAX out-of-memory errors are typically raised as XlaRuntimeError,
            # not ValueError, so match on the message rather than the type.
            if "RESOURCE_EXHAUSTED" in str(e):
                print(f"❌ Out of memory at batch_size = {batch_size}. Reducing batch size...")
                batch_size_divisor += 1  # A larger divisor means a smaller batch size
            else:
                raise  # Re-raise unrelated errors

    raise RuntimeError("❌ Failed to find a suitable batch size. Try reducing dataset size or increasing GPU memory.")

# ---------------------------- READING DATA ---------------------------- #
adata = mt.datasets.mosta(force_download=False)

def adapt_time(x):
    if x["timepoint"] == "E9.5":
        return 9.5
    if x["timepoint"] == "E10.5":
        return 10.5
    if x["timepoint"] == "E11.5":
        return 11.5
    raise ValueError(f"Unknown timepoint: {x['timepoint']}")

adata.obs["time"] = adata.obs.apply(adapt_time, axis=1).astype("category")
sc.tl.pca(adata, svd_solver="arpack")

# ---------------------------- INITIALIZING PROBLEM & dfs & set graph (omitted) ---------------------------- #

# ---------------------------- SOLVING PROBLEM ---------------------------- #
optimal_batch_size = find_max_batch_size(tp, adata, device)

# Solve the problem with the optimal batch size
tp.solve(
    epsilon=1e-3,
    scale_cost="mean",
    max_iterations=1e7,
    device=device,
    batch_size=optimal_batch_size,
)

Part of the output from the terminal:

(screenshot of terminal output attached in the original issue)

Thank you very much for your attention to this matter!

MUCDK (Collaborator) commented Feb 26, 2025

Hi @tracy666,

Are you using the SpatioTemporalProblem or the TemporalProblem?

If you want to incorporate the spatial information of the mouse embryos, as we did in the paper and the tutorial, you would use the SpatioTemporalProblem. The SpatioTemporalProblem is a fused Gromov-Wasserstein problem, which in the full-rank case requires memory at least quadratic in the number of cells.
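
As a back-of-the-envelope check of the quadratic scaling (plain arithmetic, assuming the failed allocation from the error above is a single dense float32 matrix):

# Assumption: the 2,335,208,976-byte allocation is one dense float32 matrix;
# this is arithmetic only, not taken from the moscot/ott-jax internals.
n_bytes = 2_335_208_976
n_entries = n_bytes // 4        # float32 = 4 bytes per entry
print(n_entries ** 0.5)         # 24162.0, i.e. a dense 24162 x 24162 matrix

A pairwise matrix over ~24k cells like this is exactly what batch_size alone cannot shrink.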

If you want to reduce the memory, you would have to set the rank. Please refer to our documentation for details on the rank parameter.
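
For illustration, a low-rank solve might look like the following minimal sketch; it assumes an already prepared SpatioTemporalProblem stp, and rank=500 is only a placeholder to tune for your hardware:

# Minimal sketch (assumes `stp` is a prepared SpatioTemporalProblem).
# Passing a finite rank switches to the low-rank solver, whose memory
# footprint grows linearly rather than quadratically with the cell count.
stp = stp.solve(
    epsilon=1e-3,
    rank=500,            # placeholder value
    scale_cost="mean",
)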

tracy666 (Author) commented

Dear @MUCDK ,

I sincerely appreciate your continuous support and guidance.

Sorry for the lack of clarity in my previous post. At this stage, I do not need to incorporate spatial information, so I am using TemporalProblem for the MOSTA dataset.
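
For context, the omitted initialization follows the standard pattern; a minimal sketch is below (my actual script additionally builds dataframes and sets a graph cost, omitted here):

# Minimal sketch of the omitted initialization (assumes `adata` from the
# earlier snippet, with the "time" column and X_pca already computed).
from moscot.problems.time import TemporalProblem

tp = TemporalProblem(adata)
tp = tp.prepare(time_key="time", joint_attr="X_pca")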

Upon reviewing your previous advice, I realized that I may have misunderstood your suggestion: initially I assumed that adjusting either batch_size or rank would be sufficient, so in my previous attempt I only modified the batch size.

To address this, I have now conducted a new experiment in which I systematically lower the rank from 100 to 5 (I selected these values based on my basic understanding of the paper "Low-Rank Sinkhorn Factorization"; if this range is inappropriate, please let me know). For each rank value, I also reduce batch_size by repeatedly halving it, starting from half the dataset size, until it falls below a threshold of 50 (if a lower batch size is advisable, I would greatly appreciate your input). The procedure is sketched below.
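
In simplified form (the attached script is the authoritative version; this sketch reuses tp, adata, and device from above):

# Simplified sketch of the rank/batch_size sweep described above.
for rank in (100, 50, 20, 10, 5):
    batch_size = adata.shape[0] // 2
    while batch_size >= 50:
        try:
            tp.solve(
                epsilon=1e-3,
                rank=rank,
                scale_cost="mean",
                device=device,
                batch_size=batch_size,
            )
            print(f"Success with rank={rank}, batch_size={batch_size}")
            break
        except Exception as e:  # JAX OOM surfaces as an XlaRuntimeError
            if "RESOURCE_EXHAUSTED" not in str(e):
                raise
            batch_size //= 2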

Despite these adjustments, I am still unable to find a configuration that runs successfully. Could you kindly review my approach and let me know if there is anything I might be overlooking, or any further modifications I should try?

For your reference, I have attached my full code and the log file detailing the execution process.

Thank you once again for your patience and invaluable assistance!


My code (drop the .txt suffix to get the exact .py file I use):

timepoint_mapping_python_version_X_pca_with_logger_with_rank.py.txt

The log file it created:

X_pca_2025-02-26_16-11-14.log
