updated readme for vit #99

Open
wants to merge 1 commit into
base: main
24 changes: 24 additions & 0 deletions ivy_models/vit/README.rst
@@ -50,6 +50,30 @@ Getting started
# Instantiate vit_h_14 model
ivy_vit_h_14 = vit_h_14(pretrained=True)

# Convert the Torch image tensor to an Ivy tensor and adjust dimensions
img = ivy.asarray(torch_img.permute((0, 2, 3, 1)), dtype="float32", device="gpu:0")
Contributor review comment:

Hey! Really sorry for the late review here, it slipped through somehow 😅
Just a small detail: could you please load an image (or the NumPy array file, whichever is relevant) from the image folder in the repo, and update the returned results if necessary? Thanks!


# Compile the Ivy vit_h_14 model with the Ivy image tensor
ivy_vit_h_14.compile(args=(img,))

# Pass the Ivy image tensor through the Ivy vit_h_14 model and apply softmax
output = ivy.softmax(ivy_vit_h_14(img))

# Get the indices of the top 3 classes from the output probabilities
classes = ivy.argsort(output[0], descending=True)[:3]

# Retrieve the softmax probabilities corresponding to the top 3 classes
probs = ivy.gather(output[0], classes)

print("Indices of the top 3 classes are:", classes)
print("Probabilities of the top 3 classes are:", probs)
print("Categories of the top 3 classes are:", [categories[i] for i in classes.to_list()])


``Indices of the top 3 classes are: ivy.array([457, 655, 691], dev=gpu:0)``
``Probabilities of the top 3 classes are: ivy.array([0.03149041, 0.02733098, 0.02412809], dev=gpu:0)``
``Categories of the top 3 classes are: ['bow tie', 'miniskirt', 'oxygen mask']``
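
Because the model output is passed through ``ivy.softmax`` before the top 3 entries are selected, the three values reported are probabilities, not raw scores. The following is a minimal pure-Python sketch of that softmax-then-top-k step. The helper names ``softmax`` and ``top_k`` and the five input scores are made up for illustration and are not part of ivy_models or Ivy.

```python
import math


def softmax(scores):
    """Convert raw scores to probabilities that sum to 1."""
    # Subtract the maximum first for numerical stability.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(values, k):
    """Return the indices of the k largest values, largest first."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    return order[:k]


# Hypothetical raw model scores for five classes.
scores = [2.0, 0.5, 3.0, 1.0, -1.0]

probs = softmax(scores)
top3 = top_k(probs, 3)
top3_probs = [probs[i] for i in top3]

print("Top 3 indices:", top3)
print("Top 3 probabilities:", top3_probs)
```

Softmax preserves ordering, so sorting the probabilities selects the same classes as sorting the raw scores would; only the reported values differ.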

The pretrained vit_h_14 model is now ready to be used, and is compatible with any other PyTorch code.
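
The ``torch_img.permute((0, 2, 3, 1))`` call in the snippet moves the channel axis from position 1 to the last position, that is, from a channels-first (N, C, H, W) layout to a channels-last (N, H, W, C) layout. That the Ivy model expects channels-last input is inferred from this call, not stated in the README. The sketch below shows the same reordering on plain nested lists; the helper name ``nchw_to_nhwc`` and the toy data are hypothetical.

```python
def nchw_to_nhwc(batch):
    """Reorder a nested-list tensor from (N, C, H, W) to (N, H, W, C)."""
    return [
        [
            [
                [image[c][h][w] for c in range(len(image))]
                for w in range(len(image[0][0]))
            ]
            for h in range(len(image[0]))
        ]
        for image in batch
    ]


# One 2x2 "image" with 3 channels, stored channels-first.
nchw = [[
    [[1, 2], [3, 4]],     # channel 0
    [[5, 6], [7, 8]],     # channel 1
    [[9, 10], [11, 12]],  # channel 2
]]

nhwc = nchw_to_nhwc(nchw)

# Each pixel now holds its three channel values together.
print(nhwc[0][0][0])
```

In the channels-last result, indexing by batch, row, and column yields the full list of channel values for one pixel.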

Citation