Skip to content

Commit

Permalink
fix bug in 'device' type kvstore (apache#12350)
Browse files Browse the repository at this point in the history
* fix bug in 'device' type kvstore

When we init a key after another key pushed. This key has no merged_buf_ in file 'comm.h', but the inited_ is true. So it can't pull this new key. 
```
import mxnet as mx
a=mx.nd.array([1,2,3], ctx=mx.gpu(0))
b=mx.nd.array([0,0,0], ctx=mx.gpu(0))
kv=mx.kv.create('device')
kv.init('1', a)
kv.push('1', [a,a,a,a])
kv.pull('1', b)
kv.init('2', a)
kv.pull('2', b)
```

* add kv test pull
  • Loading branch information
solin319 authored and anirudh2290 committed Sep 19, 2018
1 parent 77e173f commit 796bd0e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/kvstore/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ class CommDevice : public Comm {
void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32) override {
sorted_key_attrs_.emplace_back(key, shape, dtype);
inited_ = false;
}

void InitBuffersAndComm(const std::vector<NDArray>& src) {
Expand Down Expand Up @@ -701,8 +702,10 @@ class CommDevice : public Comm {
}
// Delayed allocation - as the dense merged buffer might not be used at all if push()
// only sees sparse arrays
bool delay_alloc = true;
buf.merged = NDArray(shape, ctx, delay_alloc, type);
if (buf.merged.is_none()) {
bool delay_alloc = true;
buf.merged = NDArray(shape, ctx, delay_alloc, type);
}
ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,23 @@ def check_init(kv, key):
check_init(mx.kv.create(), 3)
check_init(mx.kv.create(), 'a')

@with_seed()
def test_pull():
"""test pull"""
def check_pull(kv):
a = mx.nd.ones(shape)
b = mx.nd.zeros(shape)
kv.init('1', mx.nd.zeros(shape))
kv.push('1', [a,a,a,a])
kv.pull('1', b)
check_diff_to_scalar(b, 4)
kv.init('2', mx.nd.zeros(shape))
kv.pull('2', b)
check_diff_to_scalar(b, 0)

check_pull(mx.kv.create('device'))
check_pull(mx.kv.create())

@with_seed()
def test_list_kv_pair():
"""list key-value pair push & pull"""
Expand Down

0 comments on commit 796bd0e

Please sign in to comment.