Skip to content

Commit

Permalink
fix array data iter provide size not properly reported (see apache#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Nov 9, 2015
1 parent 315c1a4 commit 722eb36
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,10 @@ function ArrayDataProvider(data::Any, label::Any; batch_size::Int=0, shuffle::Bo
end

function provide_data(provider::ArrayDataProvider)
return collect(zip(provider.data_names, map(size, provider.data_arrays)))
return collect(zip(provider.data_names, map(size, provider.data_batch)))
end
function provide_label(provider::ArrayDataProvider)
return collect(zip(provider.label_names, map(size, provider.label_arrays)))
return collect(zip(provider.label_names, map(size, provider.label_batch)))
end
get_batch_size(provider::ArrayDataProvider) = provider.batch_size

Expand Down
7 changes: 7 additions & 0 deletions test/unittest/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ function test_arrays_impl(data::Vector, label::Vector, provider::mx.ArrayDataPro
batch_size = mx.get_batch_size(provider)
idx_all = 1:batch_size:sample_count

for (d1, (_, d2)) in zip(data, mx.provide_data(provider))
@test size(d1)[1:end-1] == d2[1:end-1]
end
for (d1, (_, d2)) in zip(label, mx.provide_label(provider))
@test size(d1)[1:end-1] == d2[1:end-1]
end

info("IO::Array::#data=$(length(data)),#label=$(length(label)),batch_size=$batch_size")
for (idx, batch) in zip(idx_all, provider)
data_batch = [x[[Colon() for i=1:ndims(x)-1]..., idx:min(idx+batch_size-1,sample_count)] for x in data]
Expand Down

0 comments on commit 722eb36

Please sign in to comment.