This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve symbol bindings #5870

Merged: 14 commits merged into apache:master on Jun 3, 2017

Conversation

@reminisce (Contributor) commented Apr 17, 2017

This PR:

  1. moves the work of mx.symbol.simple_bind (creating the ndarrays it needs, etc.) to the backend;
  2. makes _bind_ith_exec call simple_bind instead of bind.

Benchmark Environment: p2.xlarge
Step 1. Create an executor group to serve as shared_group.
Step 2. Call the DataParallelExecutorGroup constructor 500 times, passing the executor group created in Step 1 as shared_group.

Timed Step 2 results:
New simple_bind: 18.3 ms per executor group creation
Old simple_bind: 19.4 ms per executor group creation

Benchmark script

import mxnet as mx
import time
from mxnet.module.executor_group import DataParallelExecutorGroup
import argparse

parser = argparse.ArgumentParser(description="Benchmark executor group creation with simple_bind",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-layers', type=int, default=5,
                    help='number of stacked RNN layers')
parser.add_argument('--num-hidden', type=int, default=200,
                    help='hidden layer size')
parser.add_argument('--num-embed', type=int, default=200,
                    help='embedding layer size')
parser.add_argument('--batch-size', type=int, default=32,
                    help='the batch size.')
parser.add_argument('--repeat', type=int, default=500, help='number of times to repeat binding')
args = parser.parse_args()


def sym_gen(seq_len):
    len_vocab = 50
    stack = mx.rnn.SequentialRNNCell()
    for i in range(args.num_layers):
        stack.add(mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_l%d_' % i))

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=len_vocab,
                             output_dim=args.num_embed, name='embed')

    stack.reset()
    outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

    pred = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden))
    pred = mx.sym.FullyConnected(data=pred, num_hidden=len_vocab, name='pred')

    label = mx.sym.Reshape(label, shape=(-1,))
    pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

    return pred


def benchmark_simple_bind():
    contexts = [mx.cpu(0), mx.cpu(1)]
    workload = [1] * len(contexts)
    batch_size = args.batch_size
    bucket_size = 10
    data_shapes = [('data', (batch_size, bucket_size))]
    label_shapes = [('softmax_label', (batch_size, bucket_size))]

    # generate an RNN symbol unrolled to bucket_size time steps
    sym = sym_gen(bucket_size)
    arg_names = sym.list_arguments()
    input_names = [name[0] for name in data_shapes] + [name[0] for name in label_shapes]
    shared_arg_names = [name for name in arg_names if name not in input_names]
    shared_exec_group = DataParallelExecutorGroup(symbol=sym, contexts=contexts,
                                                  workload=workload, data_shapes=data_shapes,
                                                  label_shapes=label_shapes, param_names=shared_arg_names,
                                                  for_training=True, inputs_need_grad=False)

    sym_list = []
    data_shape_list = []
    label_shape_list = []
    repeat = args.repeat
    for i in range(repeat):
        bucket_size = i % 5 + 1
        data_shape_list.append([('data', (batch_size, bucket_size))])
        label_shape_list.append([('softmax_label', (batch_size, bucket_size))])
        sym_list.append(sym_gen(bucket_size))

    start = time.time()
    for sym, data_shapes, label_shapes in zip(sym_list, data_shape_list, label_shape_list):
        exec_group = DataParallelExecutorGroup(symbol=sym, contexts=contexts,
                                               workload=workload, data_shapes=data_shapes,
                                               label_shapes=label_shapes, param_names=shared_arg_names,
                                               for_training=True, inputs_need_grad=False,
                                               shared_group=shared_exec_group)
    end = time.time()
    print "Repeated %d times, cost %0.1f ms per executor group creation" \
          % (repeat, 1000.0 * (end - start) / repeat)

if __name__ == "__main__":
    benchmark_simple_bind()

if group2ctx is not None:
    if attr_dict is None:
        attr_dict = self.attr_dict()
    arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
Contributor
Can this be moved to backend?

Contributor Author
Sure.

    num_shared_exec_aux_states = 0
    shared_exec_aux_state_handles = ctypes.POINTER(NDArrayHandle)()
else:
    shared_exec_in_arg_handles = [nd.handle for nd in shared_exec.arg_arrays]
Contributor
These data structures are available in backend I think?

Contributor Author
I think you are right. GraphExecutor::data_entry_ has all the NDArrays of the full graph. To make use of it, I want to confirm that shared_exec's indexed graph is the same as the current executor's indexed graph, right?

}
}
g = nnvm::pass::InferShape(g, arg_shapes, "__shape__");
g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__");
Contributor
I'm new to MXNet; why can't we use arg_shape_map and arg_dtype_map directly?

Contributor Author
I don't quite understand the question, but the shape attr on the graph is designed as a vector; its indices correspond to the entries of the graph's other data structures.
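A minimal sketch of what that means in practice, assuming an nnvm-style graph g and a node id nid; the headers and attribute names below follow the usual nnvm conventions rather than quoting this PR:

#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>

// Shapes live in one flat vector indexed by entry id, not in a name-keyed map.
nnvm::TShape OutputShapeOf(const nnvm::Graph& g, uint32_t nid) {
  const nnvm::IndexedGraph& idx = g.indexed_graph();
  const auto& shapes = g.GetAttr<nnvm::ShapeVector>("shape");  // one TShape per graph entry
  return shapes[idx.entry_id(nid, 0)];  // look up by index, aligned with the graph's other attrs
}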

Contributor
Sorry, I thought the shape attr is a map.

NDArray** shared_exec_in_arg_ptrs =
    reinterpret_cast<NDArray**>(shared_exec_in_arg_handles);
for (mx_uint i = 0; i < num_shared_exec_in_args; ++i) {
  shared_exec_in_args.push_back(*shared_exec_in_arg_ptrs[i]);
Contributor
Use the vector's constructor instead of push_back.
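A small, self-contained illustration of the suggestion, using plain ints in place of the actual NDArray handles:

#include <cstddef>
#include <vector>

void build_vectors(const int* data, std::size_t n) {
  // Element-by-element growth, as in the quoted loop:
  std::vector<int> grown;
  for (std::size_t i = 0; i < n; ++i) {
    grown.push_back(data[i]);
  }

  // The suggested alternative: build the whole vector at once with the range constructor.
  std::vector<int> constructed(data, data + n);
}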

if (mutable_nodes.count(nid)) {  // aux_states
  if (has_shared_exec) {
    const NDArray& aux_nd = shared_exec_aux_states[aux_top];
    CHECK_EQ(inferred_shape, aux_nd.shape())
        << "Inferred shape does not match shared_exec.aux_array's shape";
Contributor
Improve the error message here.
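For example, the message could name the offending aux state and print both shapes; a hedged sketch only (arg_name is illustrative and may not be in scope at this exact point in the quoted code):

CHECK_EQ(inferred_shape, aux_nd.shape())
    << "Inferred shape " << inferred_shape << " of aux state '" << arg_name
    << "' does not match the shape " << aux_nd.shape()
    << " of the corresponding array in shared_exec";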

@reminisce reminisce force-pushed the improve_symbol_bind branch 2 times, most recently from 8bf8f28 to 66db99a Compare May 5, 2017 04:56
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::vector<OpReqType>& grad_req_types,
Member
We'll need a map for storage_type, too...

"""Binds current symbol to get an executor, allocate all the arguments needed.
def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
param_names=None, shared_exec=None, shared_data_arrays=None, **kwargs):
"""This function is DEPRECATED.
Member
This one is also deprecated??

Member
We'll also need a storage_type_dict for the sparse feature. Maybe we can add that in our branch.

Contributor Author
Oops, I probably accidentally copied the description of simple_bind_v1 and forgot to delete the deprecation statement.

arg_shape_map[provided_arg_shape_names[i]] =
    TShape(provided_arg_shape_data+provided_arg_shape_idx[i],
           provided_arg_shape_data+provided_arg_shape_idx[i+1]);
}
Member
The map also has emplace(), which avoids the extra copy assignment.
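A small self-contained example of the difference, with a string-to-vector map standing in for arg_shape_map:

#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, std::vector<int>> shape_map;

  // operator[] default-constructs the mapped value first, then assigns over it:
  shape_map["data"] = std::vector<int>{32, 10};

  // emplace() constructs the key/value pair in place, skipping the extra assignment:
  shape_map.emplace("softmax_label", std::vector<int>{32, 10});
  return 0;
}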

Contributor Author
Good point.


// create para name set for sharing data array memory
std::unordered_set<std::string> param_name_set;
for (mx_uint i = 0; i < num_param_names; ++i) {
Member
use std::unordered_set<std::string> param_name_set(num_param_names) to avoid resizing.

Contributor Author
Done.

@eric-haibin-lin (Member)

Could you resolve the conflicts and fix lint?

@reminisce reminisce force-pushed the improve_symbol_bind branch 2 times, most recently from 2867f6d to e78166d Compare May 13, 2017 22:30
@reminisce (Contributor Author)

@eric-haibin-lin Fixed.

@piiswrong (Contributor)

Generally looks good. Please remove the *_v1 functions from the front end and add more tests.

    provided_grad_req_types = [c_str(grad_req)]
elif isinstance(grad_req, list):
    if len(grad_req) == 0:
        raise RuntimeError('grad_req in simple_bind cannot be an empty list')
Contributor
Remove this and instead check in the backend that len(grad_req) == num_inputs.
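A hedged sketch of the kind of backend check being suggested, using the dmlc CHECK macros; num_input_arguments is an illustrative name for whatever count the backend derives, not a variable from this PR:

// One grad_req entry is expected per input argument.
CHECK_EQ(grad_req_types.size(), num_input_arguments)
    << "grad_req must specify one entry per input argument, got "
    << grad_req_types.size() << " for " << num_input_arguments << " arguments";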

@reminisce reminisce changed the title [WIP]Improve symbol bindings Improve symbol bindings May 14, 2017
@reminisce reminisce changed the title Improve symbol bindings [WIP]Improve symbol bindings May 14, 2017
@eric-haibin-lin (Member)

Is this good to merge?

@reminisce (Contributor Author)

@eric-haibin-lin This still needs tests for the executor group. After discussing with @piiswrong, we can do it this way: I submit this PR to your repo so you can use it for the sparse tensor work while I finish writing tests for the executor group. Let me know what you think.

@eric-haibin-lin (Member) commented May 15, 2017

@reminisce are you gonna submit another PR to dmlc/mxnet after your test is done? What's the implication when the sparse branch merges with dmlc/mxnet?

@reminisce (Contributor Author)

@eric-haibin-lin If I submit this PR to your repo, I will submit the test PR to your repo too. You can merge with dmlc/master once your sparse tensor work is done on top of my symbol-binding changes. Does this work for you?

@eric-haibin-lin (Member)

Yeah that sounds good!

@reminisce (Contributor Author)

I submitted this PR to eric-haibin-lin#31 for sparse tensor development. Please do not merge this PR into dmlc/master.

@reminisce reminisce force-pushed the improve_symbol_bind branch 2 times, most recently from 8682567 to a07229b Compare May 23, 2017 04:36
@reminisce reminisce changed the title [WIP]Improve symbol bindings Improve symbol bindings May 23, 2017
@reminisce reminisce force-pushed the improve_symbol_bind branch 3 times, most recently from f0677b2 to 8f8889f Compare June 1, 2017 02:29
if (nd.is_none()) return default_ctx;
return nd.ctx();
};
std::vector<Context> in_arg_ctxes(in_args.size());
Contributor
Is this correct? I think the vector after transform will have 2*in_args.size() length.
Shouldn't you use reserve?

Contributor Author
According to the documentation, std::transform does the following, so the output container must already be sized to hold the results rather than merely reserved. If it is only reserved, the code does not crash, but the output vector's effective size does not change. Is my understanding correct?

template<class InputIt, class OutputIt, class UnaryOperation>
OutputIt transform(InputIt first1, InputIt last1, OutputIt d_first, 
                   UnaryOperation unary_op)
{
    while (first1 != last1) {
        *d_first++ = unary_op(*first1++);
    }
    return d_first;
}
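Both options can be made to work; a short sketch of the contrast, assuming <algorithm> and <iterator> are included and get_ctx stands for the lambda shown in the quoted diff:

// Option 1: size the output up front; std::transform then writes through the iterator.
std::vector<Context> in_arg_ctxes(in_args.size());
std::transform(in_args.begin(), in_args.end(), in_arg_ctxes.begin(), get_ctx);

// Option 2: only reserve, and append through std::back_inserter instead of a raw iterator.
std::vector<Context> in_arg_ctxes_alt;
in_arg_ctxes_alt.reserve(in_args.size());
std::transform(in_args.begin(), in_args.end(), std::back_inserter(in_arg_ctxes_alt), get_ctx);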

const std::string& arg_name = idx[nid].source->attrs.name;
if (mutable_nodes.count(nid)) {  // aux_states
  if (nullptr != shared_exec) {
    const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name);
@piiswrong (Contributor) commented Jun 1, 2017
What if arg_name doesn't exist in shared_exec->aux_state_map()? How was it handled before?

Contributor Author
This situation should not happen. If it did, map::at() would throw an exception, which is handled by the C API function.

Before this, the lookup was done in Python, where arg_name is expected to exist in shared_exec.arg_dict; if not, the dict raises an exception.
https://github.com/dmlc/mxnet/blob/35d5e54c43aa6bdb189839d4b1a1978d85263fc7/python/mxnet/module/executor_group.py#L625

} else {  // in_args
  if (shared_arg_names.count(arg_name)) {  // model parameter
    if (nullptr != shared_exec) {
      const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name);
Contributor
same

Contributor Author
Same as above.

@piiswrong (Contributor)

I'll merge this for now. We may need to improve some error messages if users get confused by it.

@eric-haibin-lin (Member)

Yes, better error messages are needed; I am aware that it segfaults when shape inference fails in graph_executor.

@piiswrong (Contributor)

Why would it segfault? @reminisce Could you fix that first?

@reminisce (Contributor Author) commented Jun 2, 2017

The functions InferShape and InferType can fail if users provide insufficient information. If they fail, the program currently continues with 0-size ndarrays, which can cause segfaults.

I will add checks for the shape and type inference results.
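A hedged sketch of such a check, assuming the stock nnvm InferShape/InferType passes, which record the number of unresolved entries in the "shape_num_unknown_nodes" and "dtype_num_unknown_nodes" graph attributes:

// Fail fast instead of continuing with 0-size ndarrays when inference is incomplete.
CHECK_EQ(g.GetAttr<size_t>("shape_num_unknown_nodes"), 0U)
    << "Shape inference failed; please provide complete shape information for all inputs";
CHECK_EQ(g.GetAttr<size_t>("dtype_num_unknown_nodes"), 0U)
    << "Type inference failed; please provide complete dtype information for all inputs";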

@reminisce reminisce force-pushed the improve_symbol_bind branch 3 times, most recently from f542fe0 to 66be337 Compare June 2, 2017 21:23
reminisce and others added 14 commits June 2, 2017 23:01

* Add init functions for simple bind in graph_executor
* Add simple_bind c_api
* Add simple bind c-api
* Assign zeros to in_args, arg_grads, and aux_states
* Add simple_bind2 python interface
* Fix python interface bugs
* Interface changes
* Fix
* Fix core dump
* Add bind_ith_exec c_api
* Change simple_bind2
* Fix seg fault
* Finish simple_bind
* Change _bind_ith_exec
* Refactor simple_bind initialization flow for bind
* Consolidate bind and simple_bind graph init flow
* Fix bug
* Clean up
* Add comments
* Clean up
* Clean up
* Minor correction
* Rename APIs in graph executor
* Refactor
* Rebase
* Delete deprecated functions
* Move more front-end work to backend
* Bug fix
* Fix failed tests
* Minor fix
* Fix lint
* Fix lint
* Revert unnecessary changes
* Revert
* Revert
* Clean up
* Fix lint
* Fix bind_ith_exec calling simple_bind
* Fix bugs for _bind_ith_exec
* Add unit test
* Fix
* Small fix
* Add bucketing test
* Skip pylint
* Use cpu to train
@piiswrong piiswrong merged commit 10af1c7 into apache:master Jun 3, 2017
Guneet-Dhillon pushed a commit to Guneet-Dhillon/mxnet that referenced this pull request Sep 13, 2017