Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name: image-batch-embeddings
# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided
# that it meets certain specs), or you can build new images using the Anyscale
# image builder at https://console.anyscale-staging.com/v2/container-images.
image_uri: anyscale/ray:2.48.0-slim-py312-cu128
image_uri: anyscale/ray:2.51.0-slim-py312-cu128
# containerfile: /home/ray/default/containerfile

# When empty, Anyscale will auto-select the instance types. You can also specify
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name: doggos-app
# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided
# that it meets certain specs), or you can build new images using the Anyscale
# image builder at https://console.anyscale-staging.com/v2/container-images.
image_uri: anyscale/ray:2.48.0-slim-py312-cu128
image_uri: anyscale/ray:2.51.0-slim-py312-cu128
# containerfile: /home/ray/default/containerfile

# When empty, Anyscale will auto-select the instance types. You can also specify
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name: train-image-model
# like anyscale/ray:2.43.0-slim-py312-cu125, a user-provided base image (provided
# that it meets certain specs), or you can build new images using the Anyscale
# image builder at https://console.anyscale-staging.com/v2/container-images.
image_uri: anyscale/ray:2.48.0-slim-py312-cu128
image_uri: anyscale/ray:2.51.0-slim-py312-cu128
# containerfile: /home/ray/default/containerfile

# When empty, Anyscale will auto-select the instance types. You can also specify
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ def __init__(self, preprocessor, model):

def __call__(self, batch, device="cuda"):
self.model.to(device)
batch["prediction"] = self.model.predict(collate_fn(batch))
batch["prediction"] = self.model.predict(collate_fn(batch, device=device))
return batch

def predict_probabilities(self, batch, device="cuda"):
self.model.to(device)
predicted_probabilities = self.model.predict_probabilities(collate_fn(batch))
predicted_probabilities = self.model.predict_probabilities(collate_fn(batch, device=device))
batch["probabilities"] = [
{
self.preprocessor.label_to_class[i]: float(prob)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,24 @@ def pad_array(arr, dtype=np.int32):
return padded_arr


def collate_fn(batch):
def collate_fn(batch, device=None):
dtypes = {"embedding": torch.float32, "label": torch.int64}
tensor_batch = {}

# If no device is provided, try to get it from Ray Train context
if device is None:
try:
device = get_device()
except RuntimeError:
# When not in Ray Train context, use CPU for testing/serving
device = "cpu"

for key in dtypes.keys():
if key in batch:
tensor_batch[key] = torch.as_tensor(
batch[key],
dtype=dtypes[key],
device=get_device(),
device=device,
)
return tensor_batch

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_probabilities(self, url):
with torch.inference_mode():
embedding = self.model.get_image_features(**inputs).cpu().numpy()
outputs = self.predictor.predict_probabilities(
collate_fn({"embedding": embedding})
collate_fn({"embedding": embedding}, device=self.device)
)
return {"probabilities": outputs["probabilities"][0]}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -895,15 +895,24 @@
"metadata": {},
"outputs": [],
"source": [
"def collate_fn(batch):\n",
"def collate_fn(batch, device=None):\n",
" dtypes = {\"embedding\": torch.float32, \"label\": torch.int64}\n",
" tensor_batch = {}\n",
" \n",
" # If no device is provided, try to get it from Ray Train context\n",
" if device is None:\n",
" try:\n",
" device = get_device()\n",
" except RuntimeError:\n",
" # When not in Ray Train context, use CPU for testing\n",
" device = \"cpu\"\n",
" \n",
" for key in dtypes.keys():\n",
" if key in batch:\n",
" tensor_batch[key] = torch.as_tensor(\n",
" batch[key],\n",
" dtype=dtypes[key],\n",
" device=get_device(),\n",
" device=device,\n",
" )\n",
" return tensor_batch\n"
Comment on lines +898 to 917
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability and avoid code duplication, import `collate_fn` from the `doggos.model` module instead of redefining it in the notebook. That way, any future change to the function in the source file is automatically reflected here.

The notebook already has the path setup needed to import from the `doggos` package.

Suggested change
"def collate_fn(batch, device=None):\n",
" dtypes = {\"embedding\": torch.float32, \"label\": torch.int64}\n",
" tensor_batch = {}\n",
" \n",
" # If no device is provided, try to get it from Ray Train context\n",
" if device is None:\n",
" try:\n",
" device = get_device()\n",
" except RuntimeError:\n",
" # When not in Ray Train context, use CPU for testing\n",
" device = \"cpu\"\n",
" \n",
" for key in dtypes.keys():\n",
" if key in batch:\n",
" tensor_batch[key] = torch.as_tensor(\n",
" batch[key],\n",
" dtype=dtypes[key],\n",
" device=get_device(),\n",
" device=device,\n",
" )\n",
" return tensor_batch\n"
"from doggos.model import collate_fn\n"

]
Expand Down Expand Up @@ -1047,7 +1056,7 @@
"source": [
"# Sample batch\n",
"sample_batch = train_ds.take_batch(batch_size=3)\n",
"collate_fn(batch=sample_batch)\n"
"collate_fn(batch=sample_batch, device=\"cpu\")\n"
]
},
{
Expand Down Expand Up @@ -1527,12 +1536,12 @@
"\n",
" def __call__(self, batch, device=\"cuda\"):\n",
" self.model.to(device)\n",
" batch[\"prediction\"] = self.model.predict(collate_fn(batch))\n",
" batch[\"prediction\"] = self.model.predict(collate_fn(batch, device=device))\n",
" return batch\n",
"\n",
" def predict_probabilities(self, batch, device=\"cuda\"):\n",
" self.model.to(device)\n",
" predicted_probabilities = self.model.predict_probabilities(collate_fn(batch))\n",
" predicted_probabilities = self.model.predict_probabilities(collate_fn(batch, device=device))\n",
" batch[\"probabilities\"] = [\n",
" {\n",
" self.preprocessor.label_to_class[i]: float(prob)\n",
Expand All @@ -1551,7 +1560,8 @@
" args_fp=os.path.join(artifacts_dir, \"args.json\"), \n",
" state_dict_fp=os.path.join(artifacts_dir, \"model.pt\"),\n",
" )\n",
" return cls(preprocessor=preprocessor, model=model)\n"
" return cls(preprocessor=preprocessor, model=model)\n",
"\n"
]
},
{
Expand Down