From 796bd0e011093b956406346b068e5e9f6f7546c8 Mon Sep 17 00:00:00 2001 From: solin319 Date: Thu, 30 Aug 2018 17:25:00 +0800 Subject: [PATCH] fix bug in 'device' type kvstore (#12350) * fix bug in 'device' type kvstore When we init a key after another key pushed. This key has no merged_buf_ in file 'comm.h', but the inited_ is true. So it can't pull this new key. ``` import mxnet as mx a=mx.nd.array([1,2,3], ctx=mx.gpu(0)) b=mx.nd.array([0,0,0], ctx=mx.gpu(0)) kv=mx.kv.create('device') kv.init('1', a) kv.push('1', [a,a,a,a]) kv.pull('1', b) kv.init('2', a) kv.pull('2', b) ``` * add kv test pull --- src/kvstore/comm.h | 7 +++++-- tests/python/unittest/test_kvstore.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 34cab3037ce9..61370a5bfaf3 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -459,6 +459,7 @@ class CommDevice : public Comm { void Init(int key, const NDArrayStorageType stype, const TShape& shape, int dtype = mshadow::kFloat32) override { sorted_key_attrs_.emplace_back(key, shape, dtype); + inited_ = false; } void InitBuffersAndComm(const std::vector& src) { @@ -701,8 +702,10 @@ class CommDevice : public Comm { } // Delayed allocation - as the dense merged buffer might not be used at all if push() // only sees sparse arrays - bool delay_alloc = true; - buf.merged = NDArray(shape, ctx, delay_alloc, type); + if (buf.merged.is_none()) { + bool delay_alloc = true; + buf.merged = NDArray(shape, ctx, delay_alloc, type); + } ctx_info[ctx.dev_id].second += shape.Size(); } inited_ = true; diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index 921a5704d54b..28d4ec262c06 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -106,6 +106,23 @@ def check_init(kv, key): check_init(mx.kv.create(), 3) check_init(mx.kv.create(), 'a') +@with_seed() +def test_pull(): + """test pull""" + def check_pull(kv): + a = mx.nd.ones(shape) + b = mx.nd.zeros(shape) + kv.init('1', mx.nd.zeros(shape)) + kv.push('1', [a,a,a,a]) + kv.pull('1', b) + check_diff_to_scalar(b, 4) + kv.init('2', mx.nd.zeros(shape)) + kv.pull('2', b) + check_diff_to_scalar(b, 0) + + check_pull(mx.kv.create('device')) + check_pull(mx.kv.create()) + @with_seed() def test_list_kv_pair(): """list key-value pair push & pull"""