Skip to content

Commit b6e6210

Browse files
gangsfGang Zhao
andauthored
[Template] Update image-search-and-classification to pass device for collate_fn (#58327)
Signed-off-by: Gang Zhao <[email protected]> Co-authored-by: Gang Zhao <[email protected]>
1 parent 66c857c commit b6e6210

File tree

7 files changed

+33
-14
lines changed

7 files changed

+33
-14
lines changed

doc/source/ray-overview/examples/e2e-multimodal-ai-workloads/configs/generate_embeddings.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ name: image-batch-embeddings
66
# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided
77
# that it meets certain specs), or you can build new images using the Anyscale
88
# image builder at https://console.anyscale-staging.com/v2/container-images.
9-
image_uri: anyscale/ray:2.48.0-slim-py312-cu128
9+
image_uri: anyscale/ray:2.51.0-slim-py312-cu128
1010
# containerfile: /home/ray/default/containerfile
1111

1212
# When empty, Anyscale will auto-select the instance types. You can also specify

doc/source/ray-overview/examples/e2e-multimodal-ai-workloads/configs/service.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ name: doggos-app
66
# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided
77
# that it meets certain specs), or you can build new images using the Anyscale
88
# image builder at https://console.anyscale-staging.com/v2/container-images.
9-
image_uri: anyscale/ray:2.48.0-slim-py312-cu128
9+
image_uri: anyscale/ray:2.51.0-slim-py312-cu128
1010
# containerfile: /home/ray/default/containerfile
1111

1212
# When empty, Anyscale will auto-select the instance types. You can also specify

doc/source/ray-overview/examples/e2e-multimodal-ai-workloads/configs/train_model.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ name: train-image-model
66
# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided
77
# that it meets certain specs), or you can build new images using the Anyscale
88
# image builder at https://console.anyscale-staging.com/v2/container-images.
9-
image_uri: anyscale/ray:2.48.0-slim-py312-cu128
9+
image_uri: anyscale/ray:2.51.0-slim-py312-cu128
1010
# containerfile: /home/ray/default/containerfile
1111

1212
# When empty, Anyscale will auto-select the instance types. You can also specify

doc/source/ray-overview/examples/e2e-multimodal-ai-workloads/doggos/doggos/infer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ def __init__(self, preprocessor, model):
1313

1414
def __call__(self, batch, device="cuda"):
1515
self.model.to(device)
16-
batch["prediction"] = self.model.predict(collate_fn(batch))
16+
batch["prediction"] = self.model.predict(collate_fn(batch, device=device))
1717
return batch
1818

1919
def predict_probabilities(self, batch, device="cuda"):
2020
self.model.to(device)
21-
predicted_probabilities = self.model.predict_probabilities(collate_fn(batch))
21+
predicted_probabilities = self.model.predict_probabilities(collate_fn(batch, device=device))
2222
batch["probabilities"] = [
2323
{
2424
self.preprocessor.label_to_class[i]: float(prob)

doc/source/ray-overview/examples/e2e-multimodal-ai-workloads/doggos/doggos/model.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,24 @@ def pad_array(arr, dtype=np.int32):
1616
return padded_arr
1717

1818

19-
def collate_fn(batch):
19+
def collate_fn(batch, device=None):
2020
dtypes = {"embedding": torch.float32, "label": torch.int64}
2121
tensor_batch = {}
22+
23+
# If no device is provided, try to get it from Ray Train context
24+
if device is None:
25+
try:
26+
device = get_device()
27+
except RuntimeError:
28+
# When not in Ray Train context, use CPU for testing/serving
29+
device = "cpu"
30+
2231
for key in dtypes.keys():
2332
if key in batch:
2433
tensor_batch[key] = torch.as_tensor(
2534
batch[key],
2635
dtype=dtypes[key],
27-
device=get_device(),
36+
device=device,
2837
)
2938
return tensor_batch
3039

doc/source/ray-overview/examples/e2e-multimodal-ai-workloads/doggos/doggos/serve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_probabilities(self, url):
4242
with torch.inference_mode():
4343
embedding = self.model.get_image_features(**inputs).cpu().numpy()
4444
outputs = self.predictor.predict_probabilities(
45-
collate_fn({"embedding": embedding})
45+
collate_fn({"embedding": embedding}, device=self.device)
4646
)
4747
return {"probabilities": outputs["probabilities"][0]}
4848

doc/source/ray-overview/examples/e2e-multimodal-ai-workloads/notebooks/02-Distributed-Training.ipynb

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -895,15 +895,24 @@
895895
"metadata": {},
896896
"outputs": [],
897897
"source": [
898-
"def collate_fn(batch):\n",
898+
"def collate_fn(batch, device=None):\n",
899899
" dtypes = {\"embedding\": torch.float32, \"label\": torch.int64}\n",
900900
" tensor_batch = {}\n",
901+
" \n",
902+
" # If no device is provided, try to get it from Ray Train context\n",
903+
" if device is None:\n",
904+
" try:\n",
905+
" device = get_device()\n",
906+
" except RuntimeError:\n",
907+
" # When not in Ray Train context, use CPU for testing\n",
908+
" device = \"cpu\"\n",
909+
" \n",
901910
" for key in dtypes.keys():\n",
902911
" if key in batch:\n",
903912
" tensor_batch[key] = torch.as_tensor(\n",
904913
" batch[key],\n",
905914
" dtype=dtypes[key],\n",
906-
" device=get_device(),\n",
915+
" device=device,\n",
907916
" )\n",
908917
" return tensor_batch\n"
909918
]
@@ -1047,7 +1056,7 @@
10471056
"source": [
10481057
"# Sample batch\n",
10491058
"sample_batch = train_ds.take_batch(batch_size=3)\n",
1050-
"collate_fn(batch=sample_batch)\n"
1059+
"collate_fn(batch=sample_batch, device=\"cpu\")\n"
10511060
]
10521061
},
10531062
{
@@ -1527,12 +1536,12 @@
15271536
"\n",
15281537
" def __call__(self, batch, device=\"cuda\"):\n",
15291538
" self.model.to(device)\n",
1530-
" batch[\"prediction\"] = self.model.predict(collate_fn(batch))\n",
1539+
" batch[\"prediction\"] = self.model.predict(collate_fn(batch, device=device))\n",
15311540
" return batch\n",
15321541
"\n",
15331542
" def predict_probabilities(self, batch, device=\"cuda\"):\n",
15341543
" self.model.to(device)\n",
1535-
" predicted_probabilities = self.model.predict_probabilities(collate_fn(batch))\n",
1544+
" predicted_probabilities = self.model.predict_probabilities(collate_fn(batch, device=device))\n",
15361545
" batch[\"probabilities\"] = [\n",
15371546
" {\n",
15381547
" self.preprocessor.label_to_class[i]: float(prob)\n",
@@ -1551,7 +1560,8 @@
15511560
" args_fp=os.path.join(artifacts_dir, \"args.json\"), \n",
15521561
" state_dict_fp=os.path.join(artifacts_dir, \"model.pt\"),\n",
15531562
" )\n",
1554-
" return cls(preprocessor=preprocessor, model=model)\n"
1563+
" return cls(preprocessor=preprocessor, model=model)\n",
1564+
"\n"
15551565
]
15561566
},
15571567
{

0 commit comments

Comments
 (0)