
Simple usage example? (eg. MNIST) #381


Open
ameya98 opened this issue Apr 6, 2024 · 5 comments

Comments

@ameya98

ameya98 commented Apr 6, 2024

Hi, thanks for making this! Is there a simple training example, perhaps with MNIST or another small dataset?

@pablo2909

+1, that would be really cool!

@GrantMcConachie

Any updates on this?

@simon-bachhuber

Maybe this helps:

```python
from dataclasses import dataclass

import grain.python as pygrain


@dataclass
class MyDataset:
    """Minimal random-access data source: Grain needs __len__ and __getitem__."""

    data: list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx]


data_source = MyDataset(list(range(6)))

# Reshuffle the 6 record indices in each of 3 epochs (single process, no sharding).
sampler = pygrain.IndexSampler(
    len(data_source),
    shuffle=True,
    seed=2,
    shard_options=pygrain.NoSharding(),
    num_epochs=3,
)

batch_size = 3
# Read records in sampler order and group every 3 consecutive records into a batch.
dl = pygrain.DataLoader(
    data_source=data_source, sampler=sampler, operations=[pygrain.Batch(batch_size)]
)

for sample in dl:
    print(sample)

# Prints:
# [0 4 1]
# [2 3 5]
# [4 3 0]
# [1 2 5]
# [5 0 4]
# [3 2 1]
```

@cnguyen10

There is another example of using Grain with tensorflow_datasets (tfds) here: TFDS with Jax.

@selamw1

selamw1 commented Dec 5, 2024

This new tutorial explores different data loading strategies, including Grain, for a simple image classification task on the MNIST dataset.
