-
Notifications
You must be signed in to change notification settings - Fork 7k
[Template] Update image-search-and-classification to pass device for collate_fn #58327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Template] Update image-search-and-classification to pass device for collate_fn #58327
Conversation
…collate_fn Signed-off-by: Gang Zhao <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request updates the image-search-and-classification template to be compatible with a recent change in Ray Train, which now requires passing the device to collate_fn. The changes correctly propagate the device parameter through the call chain in infer.py, serve.py, and model.py. The Ray image version is also updated in the configuration files. The logic in the updated collate_fn is robust, handling execution both inside and outside of a Ray Train context. My main feedback is on the Jupyter notebook, where code is duplicated from the Python modules. I've suggested replacing these duplicated code blocks with imports to improve maintainability.
| "def collate_fn(batch, device=None):\n", | ||
| " dtypes = {\"embedding\": torch.float32, \"label\": torch.int64}\n", | ||
| " tensor_batch = {}\n", | ||
| " \n", | ||
| " # If no device is provided, try to get it from Ray Train context\n", | ||
| " if device is None:\n", | ||
| " try:\n", | ||
| " device = get_device()\n", | ||
| " except RuntimeError:\n", | ||
| " # When not in Ray Train context, use CPU for testing\n", | ||
| " device = \"cpu\"\n", | ||
| " \n", | ||
| " for key in dtypes.keys():\n", | ||
| " if key in batch:\n", | ||
| " tensor_batch[key] = torch.as_tensor(\n", | ||
| " batch[key],\n", | ||
| " dtype=dtypes[key],\n", | ||
| " device=get_device(),\n", | ||
| " device=device,\n", | ||
| " )\n", | ||
| " return tensor_batch\n" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve maintainability and avoid code duplication, it's better to import collate_fn from the doggos.model module instead of redefining it in the notebook. This ensures that any future changes to the function in the source file are automatically reflected here.
The notebook already has the necessary path setup to import from the doggos package.
| "def collate_fn(batch, device=None):\n", | |
| " dtypes = {\"embedding\": torch.float32, \"label\": torch.int64}\n", | |
| " tensor_batch = {}\n", | |
| " \n", | |
| " # If no device is provided, try to get it from Ray Train context\n", | |
| " if device is None:\n", | |
| " try:\n", | |
| " device = get_device()\n", | |
| " except RuntimeError:\n", | |
| " # When not in Ray Train context, use CPU for testing\n", | |
| " device = \"cpu\"\n", | |
| " \n", | |
| " for key in dtypes.keys():\n", | |
| " if key in batch:\n", | |
| " tensor_batch[key] = torch.as_tensor(\n", | |
| " batch[key],\n", | |
| " dtype=dtypes[key],\n", | |
| " device=get_device(),\n", | |
| " device=device,\n", | |
| " )\n", | |
| " return tensor_batch\n" | |
| "from doggos.model import collate_fn\n" |
kouroshHakha
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stmp
…collate_fn (ray-project#58327) Signed-off-by: Gang Zhao <[email protected]> Co-authored-by: Gang Zhao <[email protected]>
…collate_fn (ray-project#58327) Signed-off-by: Gang Zhao <[email protected]> Co-authored-by: Gang Zhao <[email protected]>
…collate_fn (ray-project#58327) Signed-off-by: Gang Zhao <[email protected]> Co-authored-by: Gang Zhao <[email protected]> Signed-off-by: Aydin Abiar <[email protected]>
…collate_fn (ray-project#58327) Signed-off-by: Gang Zhao <[email protected]> Co-authored-by: Gang Zhao <[email protected]>
…collate_fn (ray-project#58327) Signed-off-by: Gang Zhao <[email protected]> Co-authored-by: Gang Zhao <[email protected]> Signed-off-by: Future-Outlier <[email protected]>
Description
Ray train has some change that is not backward compatible. We are updating the multimodal AI workload template to use the latest Ray image and pass device to collate_fn
Test
Tested in this workspace.