Conversation

@bw4sz (Collaborator) commented May 14, 2025

Summary

This is a large PR aimed at creating flexible dataset classes and predict_tile dataloader strategies. The deepforest.dataset.TreeDataset is one of the oldest parts of the codebase. Over time we have added other dataset classes, like TileDataset and RasterDataset, but there is no unifying structure or organization. There is a reason the single-GPU predict_tile and TreeDataset logic lasted 4 years: changing the structure required touching nearly every file in the codebase. It was far easier to redesign the datasets now, knowing that we have an immediate need for this refactoring for 2.0. Leaving them unpicked in a half-finished state would have made it difficult for anyone else to contribute to the next steps.

Motivation

  • Make predict_tile faster and easier to use.
  • Provide greater clarity and organization among dataset classes.
  • Prepare datasets for differing geometry types for milestone 2.0.
  • Clean up unused imports and naming based on config argument changes in Hydra integration #1035.

Desired Dataset Functionality

Related

I tried #1047 and found that it didn't play with pytorch lightning, and was batch_size = 1.

Major improvements

  • Introduced a single PredictionDataset class that establishes a general structure for prediction and post-processing (see the sketch after this list)
  • Split out training and cropmodel datasets into separate classes
  • Renamed TileRaster to SingleImage and MemoryRaster to TiledRaster for greater clarity
  • Added a dataloader_strategy argument (should this more accurately be 'dataset_strategy'?) to main.predict_tile
  • Added a MultiImage approach that combines multiple images per batch to take advantage of multi-image GPU batches
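
A minimal sketch of the kind of base class this introduces. The method names, default values, and module layout below are assumptions for illustration, not the exact API merged in this PR:

```python
# Hedged sketch of a prediction dataset base class; method names and defaults
# are illustrative assumptions, not the merged DeepForest implementation.
from torch.utils.data import Dataset


class PredictionDataset(Dataset):
    """Shared structure for prediction datasets: yield preprocessed crops,
    then undo the cropping when assembling image-level predictions."""

    def __init__(self, patch_size=400, patch_overlap=0.05):
        self.patch_size = patch_size
        self.patch_overlap = patch_overlap

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        # Subclasses such as SingleImage, TiledRaster, or MultiImage return a
        # crop plus enough metadata (window offsets) to map boxes back later.
        raise NotImplementedError

    def postprocess(self, predictions):
        # Translate per-crop boxes back into original image coordinates.
        raise NotImplementedError
```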

Minor improvements

  • Significant code cleanup, renaming, and deletion of unused imports

Co-pilot summary of code changes

This pull request introduces significant updates to the DeepForest project, including enhancements to prediction scaling, dataset handling, and configuration management. The changes focus on improving usability, performance, and modularity by refining documentation, restructuring datasets, and updating configuration files.

Enhancements to Prediction Scaling:

  • Updated the prediction documentation to introduce three new dataset strategies (single, batch, window) for balancing CPU/GPU memory and utilization during inference. These strategies are configurable via the dataloader_strategy parameter (see the example below).
  • Removed the Dask-based multi-GPU scaling example, simplifying the documentation and focusing on the new dataset strategies.
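
For reference, a hedged usage sketch: the dataloader_strategy values ("single", "batch", "window") come from this PR, but the weight-loading call and the exact name of the raster path argument vary between releases, so check the current documentation before copying.

```python
# Hedged usage sketch, not the authoritative API: verify argument names
# against the merged prediction docs.
from deepforest import main

model = main.deepforest()
model.load_model("weecology/deepforest-tree")  # recent releases; older versions used use_release()

predictions = model.predict_tile(
    "tile.tif",                   # path to a large raster
    patch_size=400,               # crop size used to window the tile
    patch_overlap=0.05,           # overlap between neighboring crops
    dataloader_strategy="batch",  # "single", "batch", or "window"
)
```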

Dataset Refactoring:

  • Removed the TreeDataset, TileDataset, and RasterDataset classes from src/deepforest/dataset.py and modularized the BoundingBoxDataset into a new file, src/deepforest/datasets/cropmodel.py. This improves code organization and clarity (see the import example below).
  • Updated the import path for TileDataset in the prediction example in docs/user_guide/16_prediction.md to reflect the new file structure.
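
For example, code importing BoundingBoxDataset from the old monolithic module would switch to the new location; the pre-refactor import path shown here is inferred from the old src/deepforest/dataset.py layout.

```python
# Before the refactor (single dataset.py module, path inferred from the old layout):
# from deepforest.dataset import BoundingBoxDataset

# After the refactor, datasets are organized into a package:
from deepforest.datasets.cropmodel import BoundingBoxDataset
```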

Configuration Updates:

  • Renamed the pin_images parameter to preload_images in src/deepforest/conf/config.yaml for better clarity and added a new predict.pin_memory configuration option to control memory pinning during prediction (see the sketch below).
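
A hedged sketch of overriding the renamed keys. It assumes the Hydra/OmegaConf config is exposed as model.config with attribute access and that preload_images sits in the train section; adjust to however the released config.yaml is structured.

```python
# Assumes model.config exposes the Hydra/OmegaConf config; section placement
# of preload_images (train) is an assumption for illustration.
from deepforest import main

model = main.deepforest()

# Renamed from pin_images in earlier configs.
model.config.train.preload_images = True

# New option controlling DataLoader memory pinning during prediction.
model.config.predict.pin_memory = False
```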

Code Simplification:

  • Removed unused imports from src/deepforest/callbacks.py, streamlining the file and reducing unnecessary dependencies.

Faster tests

  • The retinanet.check_model() function was performing a forward pass on model creation, which slowed tests down.

Refactored on_validation_epoch_end

Next steps

  • Local tests passing
  • Update documentation
  • Demonstrate multi-gpu scaling and functionality on 2 different datasets
  • Rebase main
  • Code review by 2 DeepForest contributors
  • Squash commits

Additional issues that need to be considered:

  • main.predict_batch is an anti-pattern that exists outside of the PyTorch Lightning trainer.predict workflow. I left it in since it does not interfere.
  • I did not make a separate training base class; this will need to be done to achieve 2.0 model integration, with all of its connectors among torchvision/transformers input types.

@bw4sz force-pushed the multi_gpu_predict_tiles branch from 95319eb to 86105b3 on May 19, 2025 20:29
@bw4sz (Collaborator, Author) commented May 20, 2025

Here are the profiling figures for the datasets. Multi-GPU, multi-processing is confirmed to be faster.

BOEM data

1 GPU

[profiling screenshot]

3 GPUs

[profiling screenshot]

NEON Data

  • Note: must be run on A100s to fit into memory!
Profiling Results Comparison:
============================================================================================================================================
+----------+-----------+------------+--------------+-----------------+----------------+
| Device   |   Workers | Strategy   |   Num Images |   Mean Time (s) |   Std Time (s) |
+==========+===========+============+==============+=================+================+
| cuda     |         0 | single     |            4 |           54.09 |              0 |
+----------+-----------+------------+--------------+-----------------+----------------+
| cuda     |         0 | batch      |            4 |           13.58 |              0 |
+----------+-----------+------------+--------------+-----------------+----------------+
| cuda     |         5 | batch      |            4 |           16.39 |              0 |
+----------+-----------+------------+--------------+-----------------+----------------+
| cuda     |         0 | single     |           10 |           54.92 |              0 |
+----------+-----------+------------+--------------+-----------------+----------------+
| cuda     |         0 | batch      |           10 |           27.33 |              0 |
+----------+-----------+------------+--------------+-----------------+----------------+
| cuda     |         5 | batch      |           10 |           19.98 |              0 |
+----------+-----------+------------+--------------+-----------------+----------------+

2 GPUs

[profiling screenshot]

@bw4sz (Collaborator, Author) commented May 21, 2025

Okay, @ethanwhite, @jveitchmichaelis, and @henrykironde, I've met the criteria I set out above. Clearly this is quite large and complex, but it is ready to be discussed.

@bw4sz changed the title from "[WIP] Dataset redesign for multi-gpu, multi-processing and multi-geometry" to "Dataset redesign for multi-gpu, multi-processing and multi-geometry" on May 22, 2025
@bw4sz linked an issue on May 22, 2025 that may be closed by this pull request
@bw4sz dismissed ethanwhite's stale review June 3, 2025 18:30

He is out on vacation; all pieces are resolved above, and if anything lingers we can address it when he returns.

@bw4sz (Collaborator, Author) commented Jun 4, 2025

This is not ready to be merged.

1. If paths is a list but dataloader_strategy is "single", it silently predicts only the first path. Complete (see the sketch below).

2. I'm seeing some window issues with the batch strategy compared to single. Complete locally; will test on BOEM data.

[screenshot] Getting way more predictions with batch and a couple of strange window issues.
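
An illustrative guard for the kind of check that resolves item 1; the function and argument names here are assumptions, not the merged code.

```python
# Illustrative guard, not the merged implementation: instead of silently
# predicting only the first element, reject a list of paths when the chosen
# strategy can handle a single image.
def validate_single_strategy_paths(paths, dataloader_strategy):
    if dataloader_strategy == "single" and isinstance(paths, (list, tuple)):
        if len(paths) > 1:
            raise ValueError(
                f"dataloader_strategy='single' accepts one path, got {len(paths)}; "
                "use the 'batch' strategy for multiple images."
            )
```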

@bw4sz (Collaborator, Author) commented Jun 9, 2025

@jveitchmichaelis and @henrykironde, this is ready to be merged. I have confirmed and corrected the edge cases.

@bw4sz (Collaborator, Author) commented Jun 10, 2025

This should be passing now. @jveitchmichaelis, let's get this merged; I'm worried we will start getting behind on other PRs and end up discouraging contributions if we let this massive thing hang. We can do follow-up PRs if needed. Once all tests pass, let's have one more look and be done here.

@jveitchmichaelis self-requested a review June 11, 2025 03:16

@jveitchmichaelis (Collaborator) left a comment:

LGTM unless @henrykironde has any further comment. Maybe just fix that mosiac typo.

class_recall: a pandas dataframe of class level recall and precision with class sizes
"""

# If all empty ground truth, return 0 recall and precision

Review comment:
"0 precision and undefined recall"?



print(f"{mosaic_df.shape[0]} predictions kept after non-max suppression")
def mosiac(predictions, iou_threshold=0.1):

Review comment:
Rename mosiac -> mosaic

@henrykironde merged commit af9458b into main on Jun 11, 2025 (11 checks passed)
@henrykironde deleted the multi_gpu_predict_tiles branch June 11, 2025 07:04