Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add Index to IO Batch #504

Merged
merged 26 commits into from
Nov 12, 2015
Merged

Add Index to IO Batch #504

merged 26 commits into from
Nov 12, 2015

Conversation

junranhe
Copy link
Contributor

@junranhe junranhe commented Nov 6, 2015

在mxioiter添加了每个batch中添加了index,用于在python中查找其他辅助信息

@@ -80,6 +80,9 @@ class PrefetcherIter : public IIterator<DataBatch> {
batch.data[i].FlatTo2D<cpu, real_t>());
(*dptr)->num_batch_padd = batch.num_batch_padd;
}
for (size_t i = 0; i < batch.batch_size; ++i) {
(*dptr)->index[i] = batch.inst_index[i];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use std::copy

@tqchen tqchen changed the title io batch 添加 index Add Index to IO Batch Nov 6, 2015
@tqchen
Copy link
Member

tqchen commented Nov 6, 2015

Thanks for the contribution! This is helpful. Given that there is another IO refactor going on #468 I will not merge it for now.
We can rebase from the on going refactor repo and push changes to there, or wait until #468 is merged in

@pluskid

@tqchen
Copy link
Member

tqchen commented Nov 6, 2015

Sorry It is #468 I referred to

@pluskid
Copy link
Contributor

pluskid commented Nov 6, 2015

Sorry for holding this back... I will implement a simple peak method hopefully today or tomorrow and get that pr merged.

Meanwhile: could you explain a little bit more on The use case of this interface? I'm thinking maybe similar information could be simply obtained by chaining the data iter with the python enumerate function. Also other data iter needs to be modified to add the same interface.

@tqchen
Copy link
Member

tqchen commented Nov 6, 2015

I think what was wanted is because each image have a unique index, so we can use it to remap labels. Say we store the imid -> bounding box in an independent map on python, and want to attach that information to the iter, while still being able to use the preproc pipeline

@pluskid
Copy link
Contributor

pluskid commented Nov 6, 2015

I see. Thanks!

@tqchen
Copy link
Member

tqchen commented Nov 8, 2015

@junranhe #468 is merged, can you rebase and update this PR? Thanks

@junranhe
Copy link
Contributor Author

junranhe commented Nov 9, 2015

不好意思,昨天周日没上git,我会尽快重新提交的

@junranhe
Copy link
Contributor Author

junranhe commented Nov 9, 2015

@tqchen
int MXDataIterGetIndex(DataIterHandle handle, uint64_t *_out_index, size_t *out_size)返回后怎么用 uint64_t *_out_index 直接转换为numpy.array?我没有找到array直接用指针初始化的方法哦,所以之前才分为MxDataIterGetBatchsize() 和 MXDataIterGetIndex(),这样能先初始化bumpy.array,再把数据拷到里面

@tqchen
Copy link
Member

tqchen commented Nov 9, 2015

hmm, I think you are right, that is easier.

https://github.com/dmlc/mxnet/blob/master/python/mxnet/base.py#L131 This function should also be modified to do what was needed. But do need to copy here as well instead of sharing.

@junranhe
Copy link
Contributor Author

junranhe commented Nov 9, 2015

@tqchen 在MXDataGetIndex里面 DataBatch为const的,所以返回*out_index = db.index.data()是需要const_cast,介意不?我一直觉得强制转换不优雅

何俊然 added 6 commits November 9, 2015 17:03
@junranhe
Copy link
Contributor Author

junranhe commented Nov 9, 2015

@tqchen 请问r_test 通不过,是要注意些什么吗?R的环境细节我不熟悉。。。

@thirdwing
Copy link
Contributor

@junranhe Because you delete the rpkg in Makefile, it is used to build the R package.

@junranhe
Copy link
Contributor Author

@thirdwing 好,改了,同时想问下可不可以吧 MXNET_CUDNN_PATH这个参数放回来?实际开发环境有多套cudnn,指定路径比较方便:)

@junranhe
Copy link
Contributor Author

@tqchen 这个pr还有需要吗?要不撤了,有需要我在重新pr,rebase得比较danteng,还有checkfail 在 一个cpp_test里面,感觉跟io没什么关系啊?我比较忙没时间深究。。。

tqchen added a commit that referenced this pull request Nov 12, 2015
@tqchen tqchen merged commit d3bd2d5 into apache:master Nov 12, 2015
@tqchen
Copy link
Member

tqchen commented Nov 12, 2015

Sorry about the delayed reply, this is merged

@tqchen tqchen mentioned this pull request Nov 14, 2015
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants