-
Notifications
You must be signed in to change notification settings - Fork 115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
How to use PyTorch dataloader #2222
Comments
RecordDataset = create_dataset(records)
TransfromedDataSet = dataset_fn(RecordDataset) 上面的函数调用是已经将DataSet中所有的数据已经Transform完成了么 |
@brightcoder01 这个没有。Transform是读数据的时候 on-the-fly 去做的。 |
|
https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset IterableDataset 看着就和 tf generator 一样,只需要提供一个 iter() 接口? |
|
|
Quick question: for TF and PyTorch users, do they need to follow different APIs to feed data in EDL? @QiJune |
In my opinion, |
As @Kelang-Tian writes, the I believe that TensorFlow users would like to use tf operators, and PyTorch users would like to use torch functions. It's hard to unify them. |
Can we use TensorFlow Dataset APIs to read data and feed the data into Pytorch models? for features, labels in dataset:
features = features.numpy()
labels = labels.numpy()
loss = forward(batch)
loss.backward() |
@workingloong
|
背景介绍
PyTorch的 Dataset class定义
我们可以发现,PyTorch要求Dataset必须提供
__len__
接口和__getitem__
接口,这就要求 数据集是已知长度的,并且是可以被随机访问的。这里与TensorFlow不同,TensorFlow的Dataset是可以从一个generator创建的,generator只要求用户实现
__next__
接口即可,并不要求__len__
接口和__getitem__
接口。因此,我们需要提出一种新的思路。
简单的做法
伪代码
The text was updated successfully, but these errors were encountered: