Skip to content

Commit

Permalink
[MXNET-374] handle row_sparse weight in parameter and trainer (apache…
Browse files Browse the repository at this point in the history
…#11001)

* + rsp parameter

* draft

* Fix optimizer pickle

* refactor and document

* add test for save load with cast_stype

* refactor trainer tests

* add test

* add back test

* raise error for load params

* add comment

* remove print

* fix doc

* CR comments

* CR comments

* change error

* remove cast stype

* fix test

* add reset kvstore to trainer

* lint

* add test to CI

* add more checks
  • Loading branch information
eric-haibin-lin authored and zheng-da committed Jun 28, 2018
1 parent f871305 commit a95880d
Show file tree
Hide file tree
Showing 9 changed files with 585 additions and 138 deletions.
1 change: 1 addition & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ integrationtest_ubuntu_gpu_dist_kvstore() {
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=gluon
}

test_ubuntu_cpu_python2() {
Expand Down
9 changes: 9 additions & 0 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ class HybridBlock(Block):
Refer `Hybrid tutorial <http://mxnet.io/tutorials/gluon/hybrid.html>`_ to see
the end-to-end usage.
"""
def __init__(self, prefix=None, params=None):
super(HybridBlock, self).__init__(prefix=prefix, params=params)
Expand Down Expand Up @@ -879,6 +880,14 @@ def __init__(self, outputs, inputs, params=None):
"Input symbols must be variable, but %s is an output of operators"%str(i)
input_names.add(i.name)

# check if any symbol is row_sparse
row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']
for i in out:
for j in i.get_internals():
assert(j.attr("__storage_type__") != str(row_sparse_storage)), \
"SymbolBlock doesn't support Parameter '%s' because its storage " \
"type is 'row_sparse'." % j.name

for i in out.list_arguments():
if i not in input_names:
self.params.get(i, allow_deferred_init=True)
Expand Down
123 changes: 111 additions & 12 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class Parameter(object):
Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.
init : Initializer, default None
Initializer of this parameter. Will use the global initializer by default.
stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
The storage type of the parameter.
grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
The storage type of the parameter's gradient.
Expand All @@ -99,12 +101,13 @@ class Parameter(object):
"""
def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False,
differentiable=True, grad_stype='default'):
differentiable=True, stype='default', grad_stype='default'):
self._var = None
self._data = None
self._grad = None
self._ctx_list = None
self._ctx_map = None
self._trainer = None
self._deferred_init = ()
self._differentiable = differentiable
self._allow_deferred_init = allow_deferred_init
Expand All @@ -116,10 +119,14 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
self.wd_mult = wd_mult
self.grad_req = grad_req
self.init = init
assert grad_stype in ['default', 'row_sparse', 'csr'], \
"grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \
" but got '%s'" % (name, grad_stype)
# sparse related storage type information
valid_stypes = ['default', 'row_sparse', 'csr']
assert grad_stype in valid_stypes, "grad_stype for Parameter '%s' must be " \
"one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, grad_stype)
assert stype in valid_stypes, "stype for Parameter '%s' must be " \
"one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, stype)
self._grad_stype = grad_stype
self._stype = stype


def __repr__(self):
Expand Down Expand Up @@ -162,6 +169,16 @@ def shape(self, new_shape):

self._shape = new_shape

def _set_trainer(self, trainer):
""" Set the trainer this parameter is associated with. """
# trainer cannot be replaced for sparse params
if self._stype != 'default' and self._trainer and trainer and self._trainer is not trainer:
raise RuntimeError(
"Failed to set the trainer for Parameter '%s' because it was already set. " \
"More than one trainers for a %s Parameter is not supported." \
%(self.name, self._stype))
self._trainer = trainer

def _check_and_get(self, arr_list, ctx):
if arr_list is not None:
if ctx is list:
Expand Down Expand Up @@ -194,6 +211,20 @@ def _check_and_get(self, arr_list, ctx):
"because the later does not include Parameters of " \
"nested child Blocks"%(self.name))

def _get_row_sparse(self, arr_list, ctx, row_id):
""" Get row_sparse data from row_sparse parameters based on row_id. """
# get row sparse params based on row ids
if not isinstance(row_id, ndarray.NDArray):
raise TypeError("row_id must have NDArray type, but %s is given"%(type(row_id)))
if not self._trainer:
raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
"Trainer is created with it."%self.name)
results = self._check_and_get(arr_list, ctx)

# fetch row sparse params from the trainer
self._trainer._row_sparse_pull(self, results, row_id)
return results

def _load_init(self, data, ctx):
"""(Re)initializes by loading from data."""
if self.shape:
Expand All @@ -208,6 +239,8 @@ def _load_init(self, data, ctx):
"Failed loading Parameter '%s' from saved params: " \
"dtype incompatible expected %s vs saved %s"%(
self.name, str(self.dtype), str(data.dtype))
if self._stype != data.stype:
data = data.tostype(self._stype)
if isinstance(ctx, Context):
ctx = [ctx]
if self._data is None:
Expand Down Expand Up @@ -243,7 +276,7 @@ def _finish_deferred_init(self):
with autograd.pause():
if data is None:
data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
ctx=context.cpu())
ctx=context.cpu(), stype=self._stype)
initializer.create(default_init)(
initializer.InitDesc(self.name, {'__init__': init}), data)

Expand Down Expand Up @@ -271,12 +304,18 @@ def _init_grad(self):
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
stype=self._grad_stype) for i in self._data]

autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
autograd.mark_variables(self._check_and_get(self._data, list),
self._grad, self.grad_req)

def _reduce(self):
"""Reduce data from multiple context."""
block = self.list_data()
data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
if self._stype == 'default':
block = self.list_data()
data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
else:
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu())
data = self.row_sparse_data(all_row_ids)
return data

def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
Expand Down Expand Up @@ -380,12 +419,58 @@ def set_data(self, data):
self._deferred_init = self._deferred_init[:3] + (data,)
return

for arr in self.list_data():
# if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync
if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
if self not in self._trainer._params_to_init:
self._trainer._reset_kvstore()

for arr in self._check_and_get(self._data, list):
arr[:] = data

def row_sparse_data(self, row_id):
"""Returns a copy of the 'row_sparse' parameter on the same context as row_id's.
The copy only retains rows whose ids occur in provided row ids.
The parameter must have been initialized on this context before.
Parameters
----------
row_id: NDArray
Row ids to retain for the 'row_sparse' parameter.
Returns
-------
NDArray on row_id's context
"""
if self._stype != 'row_sparse':
raise RuntimeError("Cannot return a copy of Parameter %s via row_sparse_data() " \
"because its storage type is %s. Please use data() instead." \
%(self.name, self._stype))
return self._get_row_sparse(self._data, row_id.context, row_id)

def list_row_sparse_data(self, row_id):
"""Returns copies of the 'row_sparse' parameter on all contexts, in the same order
as creation. The copy only retains rows whose ids occur in provided row ids.
The parameter must have been initialized before.
Parameters
----------
row_id: NDArray
Row ids to retain for the 'row_sparse' parameter.
Returns
-------
list of NDArrays
"""
if self._stype != 'row_sparse':
raise RuntimeError("Cannot return copies of Parameter '%s' on all contexts via " \
"list_row_sparse_data() because its storage type is %s. Please " \
"use data() instead." % (self.name, self._stype))
return self._get_row_sparse(self._data, list, row_id)

def data(self, ctx=None):
"""Returns a copy of this parameter on one context. Must have been
initialized on this context before.
initialized on this context before. For sparse parameters, use
:py:meth:`Parameter.row_sparse_data` instead.
Parameters
----------
Expand All @@ -396,11 +481,25 @@ def data(self, ctx=None):
-------
NDArray on ctx
"""
if self._stype != 'default':
raise RuntimeError("Cannot return a copy of Parameter '%s' on ctx %s via data() " \
"because its storage type is %s. Please use row_sparse_data() " \
"instead." % (self.name, str(ctx), self._stype))
return self._check_and_get(self._data, ctx)

def list_data(self):
"""Returns copies of this parameter on all contexts, in the same order
as creation."""
as creation. For sparse parameters, use :py:meth:`Parameter.list_row_sparse_data`
instead.
Returns
-------
list of NDArrays
"""
if self._stype != 'default':
raise RuntimeError("Cannot return copies of Parameter '%s' on all contexts via " \
"list_data() because its storage type is %s. Please use " \
"row_sparse_data() instead." % (self.name, self._stype))
return self._check_and_get(self._data, list)

def grad(self, ctx=None):
Expand Down Expand Up @@ -447,7 +546,7 @@ def var(self):
if self._var is None:
self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype,
lr_mult=self.lr_mult, wd_mult=self.wd_mult,
init=self.init)
init=self.init, stype=self._stype)
return self._var

def cast(self, dtype):
Expand Down
Loading

0 comments on commit a95880d

Please sign in to comment.