Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
MNIST is OK
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Aug 20, 2015
1 parent 7a1832c commit 910738d
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 282 deletions.
151 changes: 42 additions & 109 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,10 @@ typedef void *AtomicSymbolCreator;
typedef void *SymbolHandle;
/*! \brief handle to a AtomicSymbol */
typedef void *AtomicSymbolHandle;
/*! \brief handle to a NArrayOperator */
typedef void *OperatorHandle;
/*! \brief handle to a DataIterator */
typedef void *DataIterHandle;
/*! \brief handle to an Executor */
typedef void *ExecutorHandle;
/*! \brief handle to a DataIterator */
typedef void *DataIterHandle;
/*
* \brief return str message of the last error
* all function in this file will return 0 when success
Expand Down Expand Up @@ -353,63 +351,59 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
const mx_uint ***out_shape_data,
int *complete);
//--------------------------------------------
// Part 4: operator interface on NArray
// Part 4: Executor interface
//--------------------------------------------
/*!
* \brief create operator from symbol
* \param sym the symbol to create operator from
* \param dev_mask device mask to indicate the device type
* \param dev_id the device id we want to bind the symbol to
* \param out the corresponding function handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXOpCreate(SymbolHandle sym,
int dev_mask,
int dev_id,
OperatorHandle *out);
/*!
* \brief free the operator handle
* \param op the handle to be freed
* \brief Executor forward method
*
* \param handle executor handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXOpFree(OperatorHandle op);
MXNET_DLL int MXExecutorForward(ExecutorHandle handle);
/*!
* \brief return an array to describe the arguments
* of this operator
* \param out_size the size of output array
* \param out_array the array of parameter requirments
* \brief Excecutor run backward
*
* \param handle execute handle
* \param len lenth
* \param head_grads NArray handle for heads' gradient
*
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXOpDescribeArgs(mx_uint *out_size,
int **out_array);
MXNET_DLL int MXExecutorBackward(ExecutorHandle handle,
mx_uint len,
NArrayHandle *head_grads);

/*!
* \brief call forward on the operator
* \param op the operator handle
* \param in_data array of input narray to the operator
* \param out_data array of output NArray to hold the result
* \brief Get executor's head NArray
*
* \param handle executor handle
* \param out_size output narray vector size
* \param out out put narray handles
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXOpForward(OperatorHandle op,
NArrayHandle *in_data,
NArrayHandle *out_data);
MXNET_DLL int MXExecutorHeads(ExecutorHandle handle,
mx_uint *out_size,
NArrayHandle **out);

/*!
* \brief call backward on the operator
* \param op the operator handle
* \param grad_next array of output gradients
* \param in_data array of input narray to the operator
* \param out_data array of output narray to the operator
* \param out_grad array to holds the gradient on these input
* can be NULL if that position request is kNullOp
* \param reqs gradient request type
* \brief Generate Executor from symbol
*
* \param symbol_handle symbol handle
* \param len length
* \param in_args in args array
* \param arg_grad_store arg grads handle array
* \param grad_req_type grad req array
* \param out output executor handle
* \return 0 when success, -1 when failure happens
* \sa mxnet::Operator::GradReqType
*/
MXNET_DLL int MXOpBackward(OperatorHandle op,
NArrayHandle *grad_next,
NArrayHandle *in_data,
NArrayHandle *out_data,
NArrayHandle *out_grad,
mx_uint *reqs);
MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
int dev_mask,
int dev_id,
mx_uint len,
NArrayHandle *in_args,
NArrayHandle *arg_grad_store,
mx_uint *grad_req_type,
ExecutorHandle *out);

//--------------------------------------------
// Part 5: IO Interface
Expand Down Expand Up @@ -460,65 +454,4 @@ MXNET_DLL int MXIOGetData(DataIterHandle handle,
MXNET_DLL int MXIOGetLabel(DataIterHandle handle,
NArrayHandle *out);

//--------------------------------------------
// Part 56: Executor
//--------------------------------------------
/*!
* \brief Executor forward method
*
* \param handle executor handle
* \param len length of narray handles
* \param input input NArray handles
*
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXExecutorForward(ExecutorHandle handle,
mx_uint len,
NArrayHandle *input);

/**
* \brief Excecutor run backward
*
* \param handle execute handle
* \param len lenth
* \param head_grads NArray handle for heads' gradient
*
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXExecutorBackward(ExecutorHandle handle,
mx_uint len,
NArrayHandle *head_grads);

/**
* \brief Get executor's head NArray
*
* \param handle executor handle
* \param out_size output narray vector size
* \param out out put narray handles
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXExecutorHeads(ExecutorHandle handle,
mx_uint *out_size,
NArrayHandle **out);

/**
* \brief Generate Executor from symbol
*
* \param handle executor hanlde (to be generated)
* \param symbol_handle symbol handle
* \param len length
* \param in_args in args array
* \param arg_grad_store arg grads handle array
* \param grad_req_type grad req array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXExecutorBind(ExecutorHandle handle,
SymbolHandle symbol_handle,
int dev_mask,
int dev_id,
mx_uint len,
NArrayHandle *in_args,
NArrayHandle *arg_grad_store,
mx_uint *grad_req_type);

#endif // MXNET_C_API_H_
11 changes: 6 additions & 5 deletions include/mxnet/narray.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,13 @@ class NArray {
NArray Copy(Context ctx) const;
/*!
* \brief Slice a NArray
*
*
* \param begin begin index in first dim
* \param end end index in first dim
*
*
* \return sliced NArray
*/
NArray Slice(index_t begin, index_t end) {
inline NArray Slice(index_t begin, index_t end) const {
NArray ret = *this;
CHECK_GE(shape_.ndim(), 0) << "NArray not initialized";
CHECK_GE(shape_[0], end) << "Chunk is smaller than required";
Expand All @@ -145,15 +145,16 @@ class NArray {
}
ret.offset_ = begin * length;
}
ret.shape_[0] = end - begin;
return ret;
}
/*!
* \brief Reshape current NArray
*
*
* \param shape new shape
* \return NArray in new shape
*/
NArray Reshape(const TShape &shape) {
inline NArray Reshape(const TShape &shape) const {
CHECK_GE(shape_.Size(), shape.Size()) \
<< "required shape is larger than chunk";
NArray ret = *this;
Expand Down
2 changes: 1 addition & 1 deletion include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class OperatorProperty {
const std::vector<T> &out_data) const {
int counter = 0;
std::vector<int> out_grad_index(out_grad.size());
std::vector<int> in_data_index(out_data.size());
std::vector<int> in_data_index(in_data.size());
std::vector<int> out_data_index(out_data.size());
for (size_t i = 0; i < out_grad_index.size(); ++i) {
out_grad_index[i] = counter++;
Expand Down
57 changes: 57 additions & 0 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# coding: utf-8
""" code for executor. """
from __future__ import absolute_import

import ctypes
from .base import _LIB
from .base import c_array, c_str, mx_uint, NArrayHandle, ExecutorHandle
from .base import check_call
from .narray import NArray

class Executor(object):
""" Executor is the actual executing object of MXNet."""
def __init__(self, handle):
"""Init an executor from handle
Parameters
----------
handle: ExecutorHandle
ExecutorHandle generated by calling Bind
"""
if not isinstance(handle, ExecutorHandle):
raise TypeError("Handle type error")
self.handle = handle

def forward(self):
"""Do forward."""
check_call(_LIB.MXExecutorForward(self.handle))

def backward(self, grads):
"""Do backward on heads' gradient.
Parameters
----------
grads: Array of NArray
heads' gradient
"""
for obj in grads:
if not isinstance(obj, NArray):
raise TypeError("inputs must be NArray")
narray = c_array(NArrayHandle, [item.handle for item in grads])
check_call(_LIB.MXExecutorBackward(self.handle, len(grads), narray))

def heads(self):
"""list all heads' output narray
Returns
-------
A list of narray binded to the heads of executor.
"""
# TODO: think of access, make heads read only.
# (consider support read only NArray(NArrayView))
# Otherwise some of the internal might depends on out_data
# if user set the content of the head, the backward behavior can be incorrect.
out_size = mx_uint()
handles = ctypes.POINTER(NArrayHandle)()
check_call(_LIB.MXExecutorHeads(self.handle, ctypes.byref(out_size), ctypes.byref(handles)))
return [NArray(NArrayHandle(handles[i])) for i in range(out_size.value)]
Loading

0 comments on commit 910738d

Please sign in to comment.