diff --git a/Makefile b/Makefile index 66fc540ea185..fa2e3c214c1c 100644 --- a/Makefile +++ b/Makefile @@ -63,14 +63,14 @@ endif BIN = tests/test_simple_engine OBJ = narray_function_cpu.o -OBJCXX11 = reshape_cpu.o dag_engine.o simple_engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o +OBJCXX11 = batch_norm_cpu.o reshape_cpu.o dag_engine.o simple_engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a ifeq ($(USE_CUDA), 1) - CUOBJ += reshape_gpu.o narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o + CUOBJ += batch_norm_gpu.o reshape_gpu.o narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o endif .PHONY: clean all test lint doc @@ -105,6 +105,8 @@ convolution_cpu.o: src/operator/convolution.cc convolution_gpu.o: src/operator/convolution.cu reshape_cpu.o: src/operator/reshape.cc reshape_gpu.o: src/operator/reshape.cu +batch_norm_cpu.o: src/operator/batch_norm.cc +batch_norm_gpu.o: src/operator/batch_norm.cu io.o: src/io/io.cc iter_mnist.o: src/io/iter_mnist.cc diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h old mode 100644 new mode 100755 index db3e4b0162ac..7a8d4033becc --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -358,6 +358,16 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array); +/*! + * \brief List auxiliary states in the symbol. + * \param symbol the symbol + * \param out_size output size + * \param out_str_array pointer to hold the output string array + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array); /*! * \brief Compose the symbol on other symbols. * @@ -406,6 +416,9 @@ MXNET_DLL int MXSymbolGrad(SymbolHandle sym, * \param out_shape_size sizeof the returning array of out_shapes * \param out_shape_ndim returning array of shape dimensions of eachs input shape. * \param out_shape_data returning array of pointers to head of the input shape. + * \param aux_shape_size sizeof the returning array of aux_shapes + * \param aux_shape_ndim returning array of shape dimensions of eachs auxiliary shape. + * \param aux_shape_data returning array of pointers to head of the auxiliary shape. * \param complete whether infer shape completes or more information is needed. 
* \return 0 when success, -1 when failure happens */ @@ -420,6 +433,9 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, mx_uint *out_shape_size, const mx_uint **out_shape_ndim, const mx_uint ***out_shape_data, + mx_uint *aux_shape_size, + const mx_uint **aux_shape_ndim, + const mx_uint ***aux_shape_data, int *complete); //-------------------------------------------- // Part 4: Executor interface @@ -428,9 +444,10 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, * \brief Executor forward method * * \param handle executor handle + * \param is_train bool value to indicate whether the forward pass is for evaluation * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXExecutorForward(ExecutorHandle handle); +MXNET_DLL int MXExecutorForward(ExecutorHandle handle, bool is_train); /*! * \brief Excecutor run backward * @@ -466,6 +483,8 @@ MXNET_DLL int MXExecutorHeads(ExecutorHandle handle, * \param in_args in args array * \param arg_grad_store arg grads handle array * \param grad_req_type grad req array + * \param aux_states_len length of auxiliary states + * \param aux_states auxiliary states array * \param out output executor handle * \return 0 when success, -1 when failure happens */ @@ -476,6 +495,8 @@ MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle, NArrayHandle *in_args, NArrayHandle *arg_grad_store, mx_uint *grad_req_type, + mx_uint aux_states_len, + NArrayHandle *aux_states, ExecutorHandle *out); //-------------------------------------------- diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index c9d206104437..fc72b0f91e18 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -414,6 +414,7 @@ struct NArrayFunctionReg } // namespace mxnet namespace dmlc { +/*!\brief traits */ DMLC_DECLARE_TRAITS(has_saveload, mxnet::NArray, true); } // namespace dmlc #endif // MXNET_NARRAY_H_ diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h old mode 100644 new mode 100755 index 57c8c6c85098..5fe7f2d7ee5a --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -80,12 +80,15 @@ class Operator { * \param req the request types of saving operation, can only be kWriteTo or kWriteInplace. * \param out_data array of output data, pointer is used to indicate that this is holder * the space of TBlob in out_data must be pre-allocated with InferShape + * \param aux_states Auxiliary states of operator. Normally operator doesn't + * need, epecial case like Batch Norm requires. * \sa OpReqType, OpContext */ virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data) = 0; + const std::vector &out_data, + const std::vector &aux_states) = 0; /*! * \brief Perform a Backward Operation, write gradient to the in_grad. * @@ -111,6 +114,7 @@ class Operator { * \param out_data the array of output data. * \param req request types of the saving operation, can be all types. * \param in_grad the array of gradient we need to write to. + * \param aux_states Auxiliary states of operator. Normally operator doesn't need * \sa OperatorProperty, OpReqType, OpContext */ virtual void Backward(const OpContext &ctx, @@ -118,7 +122,8 @@ class Operator { const std::vector &in_data, const std::vector &out_data, const std::vector &req, - const std::vector &in_grad) { + const std::vector &in_grad, + const std::vector &aux_states) { LOG(FATAL) << "Backward is not implemented"; } }; @@ -158,6 +163,13 @@ class OperatorProperty { virtual std::vector ListReturns() const { return {"output"}; } + /*! 
+ * \brief Get name of auxilary states of Operator + * \return name of return values. + */ + virtual std::vector ListAuxiliaryStates() const { + return {}; + } /*! \return number of real return values of the Operator */ virtual int NumReturns() const { return 1; @@ -189,11 +201,14 @@ class OperatorProperty { * * \param out_shape the shape of outputs of the operator * InferShape will modify the vector to fill output TShape + * \param aux_shape the shape of auxiliary states of the operator + * InferShape will modify the vector to fill output TShape * \return true if the shape inference is successful, false if there is not enough information. * \throws dmlc::Error if the known arg_shapes are inconsistent. */ virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const = 0; + std::vector *out_shape, + std::vector *aux_shape) const = 0; /*! * \brief Copy this OperatorProperty. * \return a pointer to the copied OperatorProperty diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h old mode 100644 new mode 100755 index 15b143c87bf7..97eed74a53be --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -128,10 +128,12 @@ class StaticGraph { * * \param topo_order The topological order of node index, as created by TopoSort. * \param node_out_shapes The shapes of the each outputs of nodes in the graph. + * \param node_aux_shapes The shapes of the each auxiliary states of nodes in the graph. * \return if the shape inference is successful, return true, else return false. */ bool InferNodeShapes(const std::vector &topo_order, - std::vector > *node_out_shapes) const; + std::vector > *node_out_shapes, + std::vector > *node_aux_shapes) const; /*! * \brief infer the shapes of outputs and unknown input arguments * \param in_shape the shape of input arguments of the operator @@ -144,10 +146,13 @@ class StaticGraph { * * \param out_shape the shape of outputs of the operator * InferShape will modify the vector to fill output TShape + * \param aux_shape the shape of auxiliary states of the operator + * InferShape will modify the vector to fill output TShape * \return if the shape inference is successful, return true, else return false. */ bool InferShape(std::vector* in_shape, - std::vector* out_shape) const; + std::vector* out_shape, + std::vector* aux_shape) const; /*! * \brief Add a full backward pass in the static graph. * This function will add gradient nodes for each heads, @@ -204,6 +209,8 @@ class Symbol { std::vector ListArguments() const; /*! \return get the descriptions of outputs for this symbol */ std::vector ListReturns() const; + /*! \return get the descriptions of auxiliary data for this symbol */ + std::vector ListAuxiliaryStates() const; /*! * \brief get the index th element from the returned tuple. * \param index index of multi output @@ -272,22 +279,26 @@ class Symbol { * common practice: set the shape of data input, and usually weight's shape can be infered * * \param out_shapes Use to store the infered shapes of outputs. + * \param aux_shapes Use to store the infered shapes of auxiliary states * \return true if the shape inference is successful, false if there is not enough information. * \throws dmlc::Error if the known arg_shapes are inconsistent. */ bool InferShape(std::vector *arg_shapes, - std::vector *out_shapes) const; + std::vector *out_shapes, + std::vector *aux_shapes) const; /*! * \brief infer the shapes by providing shapes of known arguments. * \param known_arg_shapes map of argument name to shape of arguments with known shapes. 
* \param arg_shapes used to store infered shapes of arguments. * \param out_shapes used to store infered shapes of outputs. + * \param aux_shapes used to store the inferred shapes of auxiliary states * \return true if the shape inference is successful, false if there is not enough information. * \throws dmlc::Error if the known arg_shapes are inconsistent. */ bool InferShape(const std::unordered_map &known_arg_shapes, std::vector *arg_shapes, - std::vector *out_shapes) const; + std::vector *out_shapes, + std::vector *aux_shapes) const; /*! * \brief get number of outputs of this symbol * \return number of outputs @@ -378,7 +389,7 @@ class Executor { * \brief Perform a Forward operation of Operator * After this operation, user can get the result by using function head. */ - virtual void Forward() = 0; + virtual void Forward(bool is_train) = 0; /*! * \brief Perform a Backward operation of the Operator. * This must be called after Forward. @@ -400,13 +411,15 @@ class Executor { * \param in_args the NArray that stores the input arguments to the symbol. * \param arg_grad_store NArray that is used to store the gradient output of the input arguments. * \param grad_req_type requirment type of gradient saving. Can only be in {kNullOp, kAddTo, kWriteTo}. + * \param aux_states NArrays that are used as internal states in the operator * \return a new executor. */ static Executor *Bind(Symbol symbol, Context ctx, const std::vector &in_args, const std::vector &arg_grad_store, - const std::vector &grad_req_type); + const std::vector &grad_req_type, + const std::vector &aux_states); }; // class operator } // namespace mxnet #endif // MXNET_SYMBOLIC_H_ diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py old mode 100644 new mode 100755 index 235c16c542f3..17df30190ce8 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -23,9 +23,16 @@ def __init__(self, handle): raise TypeError("Handle type error") self.handle = handle - def forward(self): - """Do forward.""" - check_call(_LIB.MXExecutorForward(self.handle)) + def forward(self, is_train=True): + """Do forward. + + Parameters + ---------- + is_train: bool + whether this forward pass is run in training mode + Note: for a test-only network, please indicate it in Bind (TODO) + """ + check_call(_LIB.MXExecutorForward(self.handle, is_train)) def backward(self, grads): """Do backward on heads' gradient. diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py old mode 100644 new mode 100755 index c35f84fda25d..3fb5ae665764 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1,5 +1,5 @@ # coding: utf-8 -# pylint: disable=invalid-name, protected-access, fixme +# pylint: disable=invalid-name, protected-access, fixme, too-many-arguments """Symbol support of mxnet""" from __future__ import absolute_import @@ -123,6 +123,20 @@ def list_returns(self): self.handle, ctypes.byref(size), ctypes.byref(sarr))) return [py_str(sarr[i]) for i in range(size.value)] + def list_auxiliary_states(self): + """List all auxiliary states in the symbol. + + Returns + ------- + args: list of string + List of the names of all auxiliary states. + """ + size = ctypes.c_uint() + sarr = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXSymbolListAuxiliaryStates( + self.handle, ctypes.byref(size), ctypes.byref(sarr))) + return [py_str(sarr[i]) for i in range(size.value)] + def infer_shape(self, *args, **kwargs): """Infer the shape of outputs and arguments of given known shapes of arguments.
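For context, a minimal sketch of how the new auxiliary-state queries could be exercised from Python once an operator with auxiliary states (such as the BatchNorm added by this patch) is registered; the symbol name, data shape, and the lists shown in comments are illustrative assumptions, not output captured from this change:

import mxnet as mx

data = mx.symbol.Variable('data')
net = mx.symbol.BatchNorm(data=data, name='bn')

# auxiliary states are listed separately from the regular arguments
print(net.list_arguments())          # e.g. ['data', 'bn_gamma', 'bn_beta']
print(net.list_auxiliary_states())   # e.g. ['bn_moving_mean', 'bn_moving_var']

# infer_shape now returns a third list holding the auxiliary-state shapes
arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(100, 3, 28, 28))
# expected here: aux_shapes == [(3,), (3,)], one moving mean and one moving variance per channel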
@@ -147,6 +161,9 @@ def infer_shape(self, *args, **kwargs): out_shapes : list of tuple or None List of shapes of outputs. The order is in the same order as list_returns() + aux_shapes : list of tuple or None + List of shapes of auxiliary states. + The order is in the same order as list_auxiliary_states() """ # pylint: disable=too-many-locals if len(args) != 0 and len(kwargs) != 0: @@ -176,6 +193,9 @@ def infer_shape(self, *args, **kwargs): out_shape_size = mx_uint() out_shape_ndim = ctypes.POINTER(mx_uint)() out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() + aux_shape_size = mx_uint() + aux_shape_ndim = ctypes.POINTER(mx_uint)() + aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() complete = ctypes.c_int() check_call(_LIB.MXSymbolInferShape( self.handle, len(indptr) - 1, @@ -188,13 +208,18 @@ def infer_shape(self, *args, **kwargs): ctypes.byref(out_shape_size), ctypes.byref(out_shape_ndim), ctypes.byref(out_shape_data), + ctypes.byref(aux_shape_size), + ctypes.byref(aux_shape_ndim), + ctypes.byref(aux_shape_data), ctypes.byref(complete))) if complete.value != 0: arg_shapes = [ tuple(arg_shape_data[i][:arg_shape_ndim[i]]) for i in range(arg_shape_size.value)] out_shapes = [ tuple(out_shape_data[i][:out_shape_ndim[i]]) for i in range(out_shape_size.value)] - return (arg_shapes, out_shapes) + aux_shapes = [ + tuple(aux_shape_data[i][:aux_shape_ndim[i]]) for i in range(aux_shape_size.value)] + return (arg_shapes, out_shapes, aux_shapes) else: return (None, None) # pylint: enable=too-many-locals @@ -212,7 +237,7 @@ def debug_str(self): self.handle, ctypes.byref(debug_str))) return py_str(debug_str.value) - def bind(self, ctx, args, args_grad, reqs): + def bind(self, ctx, args, args_grad, reqs, aux_states=None): """bind current symbol to get an executor.
Parameters @@ -225,15 +250,20 @@ def bind(self, ctx, args, args_grad, reqs): input args' gradient reqs: Array of enum graident requirements + aux_states: Array of NArray + input auxiliary states to the symbol """ # TODO(bing): consider a more friendly interface # For example, pass in args_grad by dict enum = {"null" : 0, "write_to" : 1, "in_place":2, "add_to" : 3} if not isinstance(ctx, Context): raise TypeError("Context type error") + if aux_states == None: + aux_states = [] args_handle = c_array(NArrayHandle, [item.handle for item in args]) args_grad_handle = c_array(NArrayHandle, [item.handle for item in args_grad]) reqs_array = c_array(mx_uint, [mx_uint(enum[item]) for item in reqs]) + aux_args_handle = c_array(NArrayHandle, [item.handle for item in aux_states]) handle = ExecutorHandle() check_call(_LIB.MXExecutorBind(self.handle, mx_uint(ctx.device_mask), @@ -242,6 +272,8 @@ def bind(self, ctx, args, args_grad, reqs): args_handle, args_grad_handle, reqs_array, + len(aux_states), + aux_args_handle, ctypes.byref(handle))) return Executor(handle) diff --git a/python/test_io.py b/python/test_io.py index d15d4cc32fcd..dfeb3f67c293 100644 --- a/python/test_io.py +++ b/python/test_io.py @@ -1,21 +1,41 @@ -#pylint: skip-file +# pylint: skip-file import mxnet as mx import numpy as np -import os +import os, gzip +import pickle as pickle +import sys +import get_data -dataiter = mx.io.MNISTIterator(path_img="/home/tianjun/data/mnist/train-images-idx3-ubyte", - path_label="/home/tianjun/data/mnist/train-labels-idx1-ubyte", - batch_size=100, shuffle=1, silent=1, input_flat="flat") +# prepare data +get_data.GetMNIST_ubyte() -dataiter.beforefirst() +batch_size = 100 +train_dataiter = mx.io.MNISTIter( + image="data/train-images-idx3-ubyte", + label="data/train-labels-idx1-ubyte", + batch_size=batch_size, shuffle=1, flat=1, silent=0, seed=10) +val_dataiter = mx.io.MNISTIter( + image="data/t10k-images-idx3-ubyte", + label="data/t10k-labels-idx1-ubyte", + batch_size=batch_size, shuffle=0, flat=1, silent=0) -idx = 0 -while dataiter.next(): - info = "Batch %d" % (idx) - idx += 1 - print info - ''' - label = dataiter.getlabel() - print label.numpy - ''' +def test_MNISTIter_loop(): + nbatch = 60000 / batch_size + batch_count = 0 + for data, label in train_dataiter: + batch_count += 1 + assert(nbatch == batch_count) + +def test_MNISTIter_reset(): + train_dataiter.reset() + train_dataiter.iter_next() + label_0 = train_dataiter.getlabel().numpy.flatten() + train_dataiter.iter_next() + train_dataiter.iter_next() + train_dataiter.iter_next() + train_dataiter.iter_next() + train_dataiter.reset() + train_dataiter.iter_next() + label_1 = train_dataiter.getlabel().numpy.flatten() + assert(sum(label_0 - label_1) == 0) diff --git a/src/c_api.cc b/src/c_api.cc old mode 100644 new mode 100755 index 85f3e52d81cd..69db585b29a7 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -47,11 +47,11 @@ struct MXAPIThreadLocalEntry { /*! \brief result holder for returning handles */ std::vector ret_handles; /*! \brief result holder for returning shapes */ - std::vector arg_shapes, out_shapes; + std::vector arg_shapes, out_shapes, aux_shapes; /*! \brief result holder for returning shape dimensions */ - std::vector arg_shape_ndim, out_shape_ndim; + std::vector arg_shape_ndim, out_shape_ndim, aux_shape_ndim; /*! 
\brief result holder for returning shape pointer */ - std::vector arg_shape_data, out_shape_data; + std::vector arg_shape_data, out_shape_data, aux_shape_data; // helper function to setup return value of shape array inline static void SetupShapeArrayReturn( const std::vector &shapes, @@ -556,6 +556,22 @@ int MXSymbolListReturns(SymbolHandle symbol, API_END(); } +int MXSymbolListAuxiliaryStates(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array) { + Symbol *s = static_cast(symbol); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + ret->ret_vec_str = std::move(s->ListAuxiliaryStates()); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); + } + *out_size = static_cast(ret->ret_vec_charp.size()); + *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); + API_END(); +} + int MXSymbolCompose(SymbolHandle sym, const char *name, mx_uint num_args, @@ -606,6 +622,9 @@ int MXSymbolInferShape(SymbolHandle sym, mx_uint *out_shape_size, const mx_uint **out_shape_ndim, const mx_uint ***out_shape_data, + mx_uint *aux_shape_size, + const mx_uint **aux_shape_ndim, + const mx_uint ***aux_shape_data, int *complete) { Symbol *s = static_cast(sym); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); @@ -617,26 +636,31 @@ int MXSymbolInferShape(SymbolHandle sym, ret->arg_shapes.push_back(TShape(arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1])); } - succ = s->InferShape(&(ret->arg_shapes), &(ret->out_shapes)); + succ = s->InferShape(&(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); } else { std::unordered_map kwargs; for (mx_uint i = 0; i < num_args; ++i) { kwargs[keys[i]] = TShape(arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]); } - succ = s->InferShape(kwargs, &(ret->arg_shapes), &(ret->out_shapes)); + succ = s->InferShape(kwargs, &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes)); } if (succ) { MXAPIThreadLocalEntry::SetupShapeArrayReturn( ret->arg_shapes, &(ret->arg_shape_ndim), &(ret->arg_shape_data)); MXAPIThreadLocalEntry::SetupShapeArrayReturn( ret->out_shapes, &(ret->out_shape_ndim), &(ret->out_shape_data)); + MXAPIThreadLocalEntry::SetupShapeArrayReturn( + ret->aux_shapes, &(ret->aux_shape_ndim), &(ret->aux_shape_data)); *in_shape_size = static_cast(ret->arg_shapes.size()); *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim); *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data); *out_shape_size = static_cast(ret->out_shapes.size()); *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim); *out_shape_data = dmlc::BeginPtr(ret->out_shape_data); + *aux_shape_size = static_cast(ret->aux_shapes.size()); + *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim); + *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data); *complete = 1; } else { *complete = 0; @@ -644,10 +668,10 @@ int MXSymbolInferShape(SymbolHandle sym, API_END(); } -int MXExecutorForward(ExecutorHandle handle) { +int MXExecutorForward(ExecutorHandle handle, bool is_train) { API_BEGIN(); Executor *exec = static_cast(handle); - exec->Forward(); + exec->Forward(is_train); API_END(); } @@ -690,21 +714,28 @@ int MXExecutorBind(SymbolHandle symbol_handle, NArrayHandle *in_args, NArrayHandle *arg_grad_store, mx_uint *grad_req_type, + mx_uint aux_states_len, + NArrayHandle *aux_states, ExecutorHandle *out) { API_BEGIN(); Symbol *symb = static_cast(symbol_handle); Context ctx = Context(dev_mask, dev_id); NArray **in_args_ptr = 
reinterpret_cast(in_args); NArray **arg_grad_ptr = reinterpret_cast(arg_grad_store); + NArray **aux_states_ptr = reinterpret_cast(aux_states); std::vector in_args_vec; std::vector arg_grad_vec; std::vector grad_req_vec; + std::vector aux_states_vec; for (mx_uint i = 0; i < len; ++i) { in_args_vec.push_back(*(in_args_ptr[i])); arg_grad_vec.push_back(*(arg_grad_ptr[i])); grad_req_vec.push_back(static_cast(grad_req_type[i])); } - *out = Executor::Bind(*symb, ctx, in_args_vec, arg_grad_vec, grad_req_vec); + for (mx_uint i = 0; i < aux_states_len; ++i) { + aux_states_vec.push_back(*(aux_states_ptr[i])); + } + *out = Executor::Bind(*symb, ctx, in_args_vec, arg_grad_vec, grad_req_vec, aux_states_vec); API_END(); } diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h old mode 100644 new mode 100755 index 7315d908aa0d..43aa4f01637a --- a/src/operator/activation-inl.h +++ b/src/operator/activation-inl.h @@ -47,7 +47,8 @@ class ActivationOp : public Operator { virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data) { + const std::vector &out_data, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); @@ -63,7 +64,8 @@ class ActivationOp : public Operator { const std::vector &in_data, const std::vector &out_data, const std::vector &req, - const std::vector &in_grad) { + const std::vector &in_grad, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); @@ -84,14 +86,13 @@ Operator* CreateOp(ActivationParam type); #if DMLC_USE_CXX11 class ActivationProp : public OperatorProperty { public: - virtual void Init(const std::vector >& kwargs) { - // TODO(bing) change directly to vector of pairs begin end - std::map kmap(kwargs.begin(), kwargs.end()); - param_.Init(kmap); + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); } - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; const TShape &dshape = in_shape->at(0); @@ -101,35 +102,35 @@ class ActivationProp : public OperatorProperty { return true; } - virtual OperatorProperty* Copy() const { + OperatorProperty* Copy() const override { auto ptr = new ActivationProp(); ptr->param_ = param_; return ptr; } - virtual std::string TypeString() const { + std::string TypeString() const override { return "Activation"; } // decalre dependency and inplace optimization options - virtual std::vector DeclareBackwardDependency( + std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {out_grad[kOut], out_data[kOut]}; } - virtual std::vector > BackwardInplaceOption( + std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { + const std::vector &in_grad) const override { return {{out_grad[kOut], in_grad[kData]}}; } - virtual std::vector > ForwardInplaceOption( + std::vector > ForwardInplaceOption( const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {{in_data[kData], out_data[kOut]}}; } diff --git 
a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h new file mode 100755 index 000000000000..0f3b303f85b6 --- /dev/null +++ b/src/operator/batch_norm-inl.h @@ -0,0 +1,271 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file batch_norm-inl.h + * \brief + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_BATCH_NORM_INL_H_ +#define MXNET_OPERATOR_BATCH_NORM_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" +#include "./mshadow_op.h" + +namespace mxnet { +namespace op { +enum BatchNormOpInputs {kData, kGamma, kBeta}; +enum BatchNormOpOutputs {kOut, kOutNoAffine, kMean, kVar}; +enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; + +struct BatchNormParam : public dmlc::Parameter { + float eps; + float momentum; + DMLC_DECLARE_PARAMETER(BatchNormParam) { + DMLC_DECLARE_FIELD(eps).set_default(1e-10f) + .describe("Epsilon to prevent div 0"); + DMLC_DECLARE_FIELD(momentum).set_default(0.1f) + .describe("Momentum for moving average"); + } +}; + +template +class BatchNormOp : public Operator { + public: + explicit BatchNormOp(BatchNormParam param) : is_init(false) { + this->param_ = param; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 3); + CHECK_EQ(aux_states.size(), 2); + if (ctx.is_train) { + CHECK_EQ(out_data.size(), 4); + CHECK_EQ(req.size(), 4); + } else { + CHECK_GE(out_data.size(), 1); + CHECK_GE(req.size(), 1); + CHECK_EQ(req[kOut], kWriteTo); + } + + Stream *s = ctx.get_stream(); + const real_t scale = static_cast(in_data[kData].shape_[1]) / + static_cast(in_data[kData].shape_.Size()); + Tensor data; + Tensor out, out_no_affine; + if (in_data[kData].ndim() == 2) { + uint32_t ds[] = {in_data[kData].shape_[0], in_data[kData].shape_[1], 1, 1}; + TShape dshape(ds, ds + 4); + data = in_data[kData].get_with_shape(dshape, s); + out = out_data[kOut].get_with_shape(dshape, s); + if (ctx.is_train) { + out_no_affine = out_data[kOutNoAffine].get_with_shape(dshape, s); + } + } else { + data = in_data[kData].get(s); + out = out_data[kOut].get(s); + if (ctx.is_train) { + out_no_affine = out_data[kOutNoAffine].get(s); + } + } + Tensor slope = in_data[kGamma].get(s); + Tensor bias = in_data[kBeta].get(s); + Tensor moving_mean = aux_states[kMovingMean].get(s); + Tensor moving_var = aux_states[kMovingVar].get(s); + // cal + if (ctx.is_train) { + Tensor mean = out_data[kMean].get(s); + Tensor var = out_data[kVar].get(s); + Assign(mean, req[kMean], scale * sumall_except_dim<1>(data)); + Assign(var, req[kVar], scale * sumall_except_dim<1>( + F(data - broadcast<1>(mean, data.shape_)))); + Assign(out_no_affine, req[kOutNoAffine], (data - broadcast<1>(mean, data.shape_)) / + F(broadcast<1>(var + param_.eps, data.shape_))); + Assign(out, req[kOut], out_no_affine * broadcast<1>(slope, out.shape_) + + broadcast<1>(bias, out.shape_)); + moving_mean = moving_mean * param_.momentum + mean * (1 - param_.momentum); + moving_var = moving_var * param_.momentum + var * (1 - param_.momentum); + } else { + Assign(out, req[kOut], broadcast<1>(slope / + F(moving_var + param_.eps), data.shape_) * data + + broadcast<1>(bias - (slope * moving_mean) / + F(moving_var + param_.eps), data.shape_)); + } + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const 
std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK_EQ(in_data.size(), 3); + CHECK_EQ(out_data.size(), 4); + CHECK_EQ(in_grad.size(), 3); + Stream *s = ctx.get_stream(); + Tensor data, grad, grad_in; + Tensor out, out_no_affine; + const real_t scale = static_cast(out_data[kOut].shape_[1]) / + static_cast(out_data[kOut].shape_.Size()); + if (in_data[kData].ndim() == 2) { + uint32_t ds[] = {out_data[kOut].shape_[0], out_data[kOut].shape_[1], 1, 1}; + TShape dshape(ds, ds + 4); + data = in_data[kData].get_with_shape(dshape, s); + grad = out_grad[kOut].get_with_shape(dshape, s); + grad_in = in_grad[kData].get_with_shape(dshape, s); + out = out_data[kOut].get_with_shape(dshape, s); + out_no_affine = out_data[kOutNoAffine].get_with_shape(dshape, s); + } else { + data = in_data[kData].get(s); + grad = out_grad[kOut].get(s); + grad_in = in_grad[kData].get(s); + out = out_data[kOut].get(s); + out_no_affine = out_data[kOutNoAffine].get(s); + } + this->Init(ctx, out.shape_); + Tensor mean = out_data[kMean].get(s); + Tensor var = out_data[kVar].get(s); + Tensor slope = in_data[kGamma].get(s); + // Tensor bias = in_data[kBeta].get(s); + Tensor gslope = in_grad[kGamma].get(s); + Tensor gbias = in_grad[kBeta].get(s); + Tensor gmean = tmp_[0]; + Tensor gvar = tmp_[1]; + Tensor tmp = tmp_[2]; + // cal + gvar = sumall_except_dim<1>((grad * broadcast<1>(slope, data.shape_)) * + (data - broadcast<1>(mean, data.shape_)) * + -0.5f * + F(broadcast<1>(var + param_.eps, data.shape_), -1.5f)); + gmean = sumall_except_dim<1>(grad * broadcast<1>(slope, data.shape_)); + gmean *= -1.0f / F(var + param_.eps); + tmp = scale * sumall_except_dim<1>(-2.0f * (data - broadcast<1>(mean, data.shape_))); + tmp *= gvar; + gmean += tmp; + // assign + Assign(gslope, req[kGamma], sumall_except_dim<1>(grad * out_no_affine)); + Assign(gbias, req[kBeta], sumall_except_dim<1>(grad)); + Assign(grad_in, req[kData], (grad * broadcast<1>(slope, data.shape_)) * + broadcast<1>(1.0f / F(var + param_.eps), data.shape_) + + broadcast<1>(gvar, data.shape_) * scale * 2.0f * (data - broadcast<1>(mean, data.shape_)) + + broadcast<1>(gmean, data.shape_) * scale); + } + + private: + // TODO(bing): use global memory allocator + inline void Init(const OpContext &ctx, + const mshadow::Shape<4> &dshape) { + if (is_init) return; + is_init = true; + mshadow::Stream *s = ctx.get_stream(); + tmp_.set_stream(s); + tmp_.Resize(mshadow::Shape2(3, dshape[1])); + } + mshadow::TensorContainer tmp_; + BatchNormParam param_; + bool is_init; +}; // class BatchNormOp + +template +Operator *CreateOp(BatchNormParam param); + + +#if DMLC_USE_CXX11 +class BatchNormProp : public OperatorProperty { + public: + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 3) << "Input:[data, gamma, beta]"; + const TShape &dshape = in_shape->at(0); + if (dshape.ndim() == 0) return false; + in_shape->at(1) = TShape(Shape1(dshape[1])); + in_shape->at(2) = TShape(Shape1(dshape[1])); + out_shape->clear(); + out_shape->push_back(dshape); + out_shape->push_back(dshape); + out_shape->push_back(Shape1(dshape[1])); + out_shape->push_back(Shape1(dshape[1])); + aux_shape->clear(); + aux_shape->push_back(Shape1(dshape[1])); + aux_shape->push_back(Shape1(dshape[1])); + return true; 
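The mshadow expressions in BatchNormOp::Forward and BatchNormOp::Backward above are compact; as a reference for the math they implement, here is a hedged NumPy sketch of the same forward and backward passes, assuming 4-D NCHW input. The function and helper names are illustrative, and the defaults mirror BatchNormParam (eps=1e-10, momentum=0.1); this restates the arithmetic, not the mshadow implementation itself.

import numpy as np

def _per_channel(v):
    """Reshape a per-channel vector so it broadcasts over an NCHW tensor."""
    return v.reshape(1, -1, 1, 1)

def bn_forward(data, gamma, beta, moving_mean, moving_var,
               eps=1e-10, momentum=0.1, is_train=True):
    axes = (0, 2, 3)  # reduce over every axis except the channel axis, like sumall_except_dim<1>
    if is_train:
        mean = data.mean(axis=axes)
        var = ((data - _per_channel(mean)) ** 2).mean(axis=axes)
        out_no_affine = (data - _per_channel(mean)) / np.sqrt(_per_channel(var) + eps)
        out = out_no_affine * _per_channel(gamma) + _per_channel(beta)
        # auxiliary states: updated as a side effect, never written by the gradient pass
        moving_mean[:] = moving_mean * momentum + mean * (1.0 - momentum)
        moving_var[:] = moving_var * momentum + var * (1.0 - momentum)
        return out, out_no_affine, mean, var
    # evaluation: fold the moving statistics into a single affine transform
    scale = gamma / np.sqrt(moving_var + eps)
    return _per_channel(scale) * data + _per_channel(beta - scale * moving_mean)

def bn_backward(grad_out, data, gamma, mean, var, out_no_affine, eps=1e-10):
    n, _, h, w = data.shape
    scale = 1.0 / (n * h * w)
    axes = (0, 2, 3)
    g = grad_out * _per_channel(gamma)
    # gradients w.r.t. the batch variance and the batch mean
    gvar = (g * (data - _per_channel(mean)) * -0.5
            * (_per_channel(var) + eps) ** -1.5).sum(axis=axes)
    gmean = g.sum(axis=axes) * (-1.0 / np.sqrt(var + eps))
    gmean += gvar * scale * (-2.0 * (data - _per_channel(mean))).sum(axis=axes)
    # gradients w.r.t. gamma, beta and the input
    gslope = (grad_out * out_no_affine).sum(axis=axes)
    gbias = grad_out.sum(axis=axes)
    grad_in = (g / np.sqrt(_per_channel(var) + eps)
               + _per_channel(gvar) * scale * 2.0 * (data - _per_channel(mean))
               + _per_channel(gmean) * scale)
    return grad_in, gslope, gbias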
+ } + + OperatorProperty* Copy() const override { + auto ptr = new BatchNormProp(); + ptr->param_ = param_; + return ptr; + } + + std::string TypeString() const override { + return "BatchNorm"; + } + + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return {out_grad[kOut], + out_data[kOut], out_data[kOutNoAffine], out_data[kMean], out_data[kVar], + in_data[kData], in_data[kGamma], in_data[kBeta]}; + } + + std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const override { + return {{out_grad[kOut], in_grad[kData]}}; + } + + int NumVisibleReturns() const override { + return 1; + } + + int NumReturns() const override { + return 4; + } + + std::vector ListArguments() const override { + return {"data", "gamma", "beta"}; + } + + std::vector ListReturns() const override { + return {"output", "output_no_affine", "mean", "var"}; + } + + std::vector ListAuxiliaryStates() const override { + return {"moving_mean", "moving_var"}; + } + + Operator* CreateOperator(Context ctx) const; + + private: + BatchNormParam param_; +}; // class BatchNormProp + +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_BATCH_NORM_INL_H_ diff --git a/src/operator/batch_norm.cc b/src/operator/batch_norm.cc new file mode 100644 index 000000000000..0c11b06fb9bc --- /dev/null +++ b/src/operator/batch_norm.cc @@ -0,0 +1,30 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file batch_norm.cc + * \brief + * \author Bing Xu +*/ + +#include "./batch_norm-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(BatchNormParam param) { + return new BatchNormOp(param); +} + +Operator *BatchNormProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(BatchNormParam); + +MXNET_REGISTER_OP_PROPERTY(BatchNorm, BatchNormProp) +.describe("Apply batch normalization to input.") +.add_argument("data", "Symbol", "Input data to batch normalization") +.add_arguments(BatchNormParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet + diff --git a/src/operator/batch_norm.cu b/src/operator/batch_norm.cu new file mode 100644 index 000000000000..6f7e04b9f171 --- /dev/null +++ b/src/operator/batch_norm.cu @@ -0,0 +1,19 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + * \file batch_norm.cu + * \brief + * \author Bing Xu +*/ + +#include "./batch_norm-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(BatchNormParam param) { + return new BatchNormOp(param); +} + +} // namespace op +} // namespace mxnet + diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h old mode 100644 new mode 100755 index 22e1e6f0d7e1..71311cc4b211 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -61,7 +61,8 @@ class ConvolutionOp : public Operator { virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data) { + const std::vector &out_data, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req[kOut], kWriteTo); @@ -77,7 +78,7 @@ class ConvolutionOp : public Operator { TShape wmat_shape(ws, ws + 3); Tensor wmat = in_data[kWeight].get_with_shape(wmat_shape, s); Tensor out = out_data[kOut].get(s); - this->InitTemp(data.shape_, out.shape_); + this->InitTemp(ctx, data.shape_, out.shape_); const index_t nbatch = data.size(0); for (index_t i = 0; i < nbatch; i += param_.nstep) { const index_t step = std::min(param_.nstep, nbatch - i); @@ -124,7 +125,8 @@ class ConvolutionOp : public Operator { const std::vector &in_data, const std::vector &out_data, const std::vector &req, - const std::vector &in_grad) { + const std::vector &in_grad, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; // TODO(bing): check the BLAS Handle, be careful @@ -144,7 +146,7 @@ class ConvolutionOp : public Operator { Tensor grad = out_grad[kOut].get(s); Tensor gdata = in_grad[kData].get(s); Tensor gwmat = in_grad[kWeight].get_with_shape(wmat_shape, s); - this->InitTemp(data.shape_, grad.shape_); + this->InitTemp(ctx, data.shape_, grad.shape_); const index_t nbatch = data.size(0); for (index_t i = 0; i < nbatch; i += param_.nstep) { const index_t step = std::min(param_.nstep, nbatch - i); @@ -208,7 +210,8 @@ class ConvolutionOp : public Operator { private: // TODO(bing): use global resource allocator - inline void InitTemp(const mshadow::Shape<4> &ishape, + inline void InitTemp(const OpContext &ctx, + const mshadow::Shape<4> &ishape, const mshadow::Shape<4> &oshape) { const int ksize_y = param_.kernel[0]; const int ksize_x = param_.kernel[1]; @@ -219,6 +222,9 @@ class ConvolutionOp : public Operator { oshape[2] * oshape[3]); int nop = (ishape[0] + param_.nstep - 1) / param_.nstep; param_.nstep = (ishape[0] + nop - 1) / nop; + mshadow::Stream *s = ctx.get_stream(); + temp_col_.set_stream(s); + temp_dst_.set_stream(s); temp_col_.Resize(mshadow::Shape2(shape_colunit_[0], shape_colunit_[1] * param_.nstep)); temp_dst_.Resize(mshadow::Shape3(shape_dstunit_[0], @@ -240,7 +246,7 @@ Operator* CreateOp(ConvolutionParam param); #if DMLC_USE_CXX11 class ConvolutionProp : public OperatorProperty { public: - virtual std::vector ListArguments() const { + std::vector ListArguments() const override { if (!param_.no_bias) { return {"data", "weight", "bias"}; } else { @@ -248,12 +254,13 @@ class ConvolutionProp : public OperatorProperty { } } - virtual void Init(const std::vector >& kwargs) { + void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { using namespace 
mshadow; if (!param_.no_bias) { CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; @@ -292,28 +299,28 @@ class ConvolutionProp : public OperatorProperty { return true; } - virtual OperatorProperty* Copy() const { + OperatorProperty* Copy() const override { auto ptr = new ConvolutionProp(); ptr->param_ = param_; return ptr; } - virtual std::string TypeString() const { + std::string TypeString() const override { return "Convolution"; } - virtual std::vector DeclareBackwardDependency( + std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {out_grad[kOut], in_data[kData], in_data[kWeight]}; } - virtual std::vector > BackwardInplaceOption( + std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { + const std::vector &in_grad) const override { return {{in_data[kData], in_grad[kData]}}; } diff --git a/src/operator/elementwise_sum-inl.h b/src/operator/elementwise_sum-inl.h old mode 100644 new mode 100755 index d4b28eb43f4c..390df9bd36e1 --- a/src/operator/elementwise_sum-inl.h +++ b/src/operator/elementwise_sum-inl.h @@ -40,7 +40,8 @@ class ElementWiseSumOp : public Operator { virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data) { + const std::vector &out_data, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(static_cast(in_data.size()), size_); @@ -86,7 +87,8 @@ class ElementWiseSumOp : public Operator { const std::vector &in_data, const std::vector &out_data, const std::vector &req, - const std::vector &in_grad) { + const std::vector &in_grad, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), static_cast(size_)); @@ -110,14 +112,15 @@ Operator* CreateOp(ElementWiseSumParam param); #if DMLC_USE_CXX11 class ElementWiseSumProp : public OperatorProperty { public: - virtual void Init(const std::vector >& kwargs) { + void Init(const std::vector >& kwargs) override { // TODO(bing) change directly to vector of pairs begin end std::map kmap(kwargs.begin(), kwargs.end()); param_.Init(kmap); } - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), static_cast(param_.size)); const TShape &dshape = in_shape->at(0); @@ -130,34 +133,34 @@ class ElementWiseSumProp : public OperatorProperty { return true; } - virtual OperatorProperty* Copy() const { + OperatorProperty* Copy() const override { auto ptr = new ElementWiseSumProp(); ptr->param_ = param_; return ptr; } - virtual std::string TypeString() const { + std::string TypeString() const override { return "ElementWiseSum"; } - virtual std::vector DeclareBackwardDependency( + std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return out_grad; } - virtual std::vector > BackwardInplaceOption( + std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { + const std::vector &in_grad) const override { return 
{{out_grad[0], in_grad[0]}}; } - virtual std::vector > ForwardInplaceOption( + std::vector > ForwardInplaceOption( const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {{in_data[0], out_data[0]}}; } diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h old mode 100644 new mode 100755 index ac5fd992cd82..ece1131f7717 --- a/src/operator/fully_connected-inl.h +++ b/src/operator/fully_connected-inl.h @@ -25,12 +25,12 @@ enum FullyConnectedOpInputs {kData, kWeight, kBias}; enum FullyConnectedOpOutputs {kOut}; struct FullyConnectedParam : public dmlc::Parameter { - int num_hidden; + int nb_hidden; bool no_bias; DMLC_DECLARE_PARAMETER(FullyConnectedParam) { // TODO(bing) change to only set lower bound // add support for boolean - DMLC_DECLARE_FIELD(num_hidden).set_range(1, 100000) + DMLC_DECLARE_FIELD(nb_hidden).set_range(1, 100000) .describe("Number of hidden nodes of the output."); DMLC_DECLARE_FIELD(no_bias).set_default(false) .describe("Whether to disable bias parameter."); @@ -51,7 +51,8 @@ class FullyConnectedOp : public Operator { virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data) { + const std::vector &out_data, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req[kOut], kWriteTo); @@ -77,7 +78,8 @@ class FullyConnectedOp : public Operator { const std::vector &in_data, const std::vector &out_data, const std::vector &req, - const std::vector &in_grad) { + const std::vector &in_grad, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); @@ -116,7 +118,7 @@ Operator* CreateOp(FullyConnectedParam param); #if DMLC_USE_CXX11 class FullyConnectedProp : public OperatorProperty { public: - virtual std::vector ListArguments() const { + std::vector ListArguments() const override { if (!param_.no_bias) { return {"data", "weight", "bias"}; } else { @@ -124,12 +126,13 @@ class FullyConnectedProp : public OperatorProperty { } } - virtual void Init(const std::vector >& kwargs) { + void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { using namespace mshadow; if (!param_.no_bias) { CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; @@ -143,37 +146,37 @@ class FullyConnectedProp : public OperatorProperty { index_t num_input = 0; mshadow::Shape<2> ishape = dshape.FlatTo2D(); num_input = ishape[1]; - SHAPE_ASSIGN_CHECK(*in_shape, kWeight, Shape2(param_.num_hidden, num_input)); + SHAPE_ASSIGN_CHECK(*in_shape, kWeight, Shape2(param_.nb_hidden, num_input)); if (!param_.no_bias) { - SHAPE_ASSIGN_CHECK(*in_shape, kBias, Shape1(param_.num_hidden)); + SHAPE_ASSIGN_CHECK(*in_shape, kBias, Shape1(param_.nb_hidden)); } out_shape->clear(); - out_shape->push_back(Shape2(dshape[0], param_.num_hidden)); + out_shape->push_back(Shape2(dshape[0], param_.nb_hidden)); return true; } - virtual OperatorProperty* Copy() const { + OperatorProperty* Copy() const override { FullyConnectedProp* fc_sym = new FullyConnectedProp(); fc_sym->param_ = this->param_; return fc_sym; } - virtual std::string TypeString() const { + std::string TypeString() const override { return "FullyConnecteded"; } // decalre dependency and inplace 
optimization options - virtual std::vector DeclareBackwardDependency( + std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {out_grad[kOut], in_data[kData], in_data[kWeight]}; } - virtual std::vector > BackwardInplaceOption( + std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { + const std::vector &in_grad) const override { return {{in_data[kData], in_grad[kData]}}; } diff --git a/src/operator/pooling-inl.h b/src/operator/pooling-inl.h old mode 100644 new mode 100755 index df6224656e31..491d21dbe810 --- a/src/operator/pooling-inl.h +++ b/src/operator/pooling-inl.h @@ -63,7 +63,8 @@ class PoolingOp : public Operator { virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data) { + const std::vector &out_data, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); @@ -98,7 +99,8 @@ class PoolingOp : public Operator { const std::vector &in_data, const std::vector &out_data, const std::vector &req, - const std::vector &in_grad) { + const std::vector &in_grad, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); @@ -152,12 +154,13 @@ Operator* CreateOp(PoolingParam param); #if DMLC_USE_CXX11 class PoolingProp : public OperatorProperty { public: - virtual void Init(const std::vector >& kwargs) { + void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { CHECK_EQ(in_shape->size(), 1); const TShape &dshape = (*in_shape)[0]; CHECK_EQ(dshape.ndim(), 4) << \ @@ -174,28 +177,28 @@ class PoolingProp : public OperatorProperty { return true; } - virtual OperatorProperty* Copy() const { + OperatorProperty* Copy() const override { PoolingProp *prop_sym = new PoolingProp(); prop_sym->param_ = this->param_; return prop_sym; } - virtual std::string TypeString() const { + std::string TypeString() const override { return "Pooling"; } - virtual std::vector DeclareBackwardDependency( + std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {out_grad[kOut], in_data[kData], out_data[kOut]}; } - virtual std::vector > BackwardInplaceOption( + std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { + const std::vector &in_grad) const override { return {{in_data[kData], in_grad[kData]}}; } diff --git a/src/operator/reshape-inl.h b/src/operator/reshape-inl.h old mode 100644 new mode 100755 index 68918c460678..8bd95c49927e --- a/src/operator/reshape-inl.h +++ b/src/operator/reshape-inl.h @@ -36,7 +36,8 @@ class ReshapeOp : public Operator { virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data) { + const std::vector &out_data, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); @@ -44,6 +45,7 @@ class ReshapeOp : 
public Operator { CHECK_EQ(out_data.size(), 1); if (req[kOut] == kNullOp) return; Stream *s = ctx.get_stream(); + // TODO(bing): potentail bug here for non-4D input Tensor data = in_data[kData].get(s); Tensor out = out_data[kOut].get(s); CHECK_EQ(data.CheckContiguous(), true); @@ -58,7 +60,8 @@ class ReshapeOp : public Operator { const std::vector &in_data, const std::vector &out_data, const std::vector &req, - const std::vector &in_grad) { + const std::vector &in_grad, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req.size(), 1); @@ -86,16 +89,17 @@ class ReshapeProp : public OperatorProperty { explicit ReshapeProp(ReshapeParam param) : param_(param) {} - virtual void Init(const std::vector >& kwargs) { + void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } - virtual std::string TypeString() const { + std::string TypeString() const override { return "Reshape"; } - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { CHECK_EQ(in_shape->size(), 1) << "Input: [data]"; const TShape &dshape = in_shape->at(kData); if (dshape.ndim() == 0) return false; @@ -108,30 +112,30 @@ class ReshapeProp : public OperatorProperty { return true; } - virtual OperatorProperty* Copy() const { + OperatorProperty* Copy() const override { auto ptr = new ReshapeProp(); ptr->param_ = param_; return ptr; } - virtual std::vector DeclareBackwardDependency( + std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {out_grad[kOut]}; } - virtual std::vector > ForwardInplaceOption( + std::vector > ForwardInplaceOption( const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {{in_data[kData], out_data[kOut]}}; } - virtual std::vector > BackwardInplaceOption( + std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { + const std::vector &in_grad) const override { return {{out_grad[kOut], in_grad[kData]}}; } @@ -143,14 +147,15 @@ class ReshapeProp : public OperatorProperty { class FlattenProp : public ReshapeProp { public: - virtual void Init(const std::vector >& kwargs) {} + void Init(const std::vector >& kwargs) override {} - virtual std::string TypeString() const { + std::string TypeString() const override { return "Flatten"; } - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { CHECK_EQ(in_shape->size(), 1) << "Input: [data]"; const TShape &dshape = in_shape->at(kData); if (dshape.ndim() == 0) return false; @@ -163,7 +168,7 @@ class FlattenProp : public ReshapeProp { return true; } - virtual OperatorProperty* Copy() const { + OperatorProperty* Copy() const override { auto ptr = new FlattenProp(); return ptr; } diff --git a/src/operator/reshape.cc b/src/operator/reshape.cc old mode 100644 new mode 100755 index f8058955708b..6bd077172d4a --- a/src/operator/reshape.cc +++ b/src/operator/reshape.cc @@ -22,12 +22,12 @@ Operator* ReshapeProp::CreateOperator(Context ctx) const { DMLC_REGISTER_PARAMETER(ReshapeParam); MXNET_REGISTER_OP_PROPERTY(Reshape, ReshapeProp) -.add_argument("data", 
"Symbol", "Input data to flatten.") -.describe("Reshape input to target shape"); +.describe("Reshape input to target shape") +.add_argument("data", "Symbol", "Input data to reshape.") +.add_arguments(ReshapeParam::__FIELDS__()); MXNET_REGISTER_OP_PROPERTY(Flatten, FlattenProp) -.add_argument("data", "Symbol", "Input data to flatten.") -.add_arguments(ReshapeParam::__FIELDS__()) -.describe("Flatten input"); +.describe("Flatten input") +.add_argument("data", "Symbol", "Input data to flatten."); } // namespace op } // namespace mxnet diff --git a/src/operator/softmax-inl.h b/src/operator/softmax-inl.h old mode 100644 new mode 100755 index 097b7beae626..cf4e2671d719 --- a/src/operator/softmax-inl.h +++ b/src/operator/softmax-inl.h @@ -39,7 +39,8 @@ class SoftmaxOp : public Operator { virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data) { + const std::vector &out_data, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 2) << "Softmax Input: [data, label]"; @@ -55,7 +56,8 @@ class SoftmaxOp : public Operator { const std::vector &in_data, const std::vector &out_data, const std::vector &req, - const std::vector &in_grad) { + const std::vector &in_grad, + const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 2); @@ -83,16 +85,17 @@ Operator* CreateOp(SoftmaxParam param); #if DMLC_USE_CXX11 class SoftmaxProp : public OperatorProperty { public: - virtual std::vector ListArguments() const { + std::vector ListArguments() const override { return {"data", "label"}; } - virtual void Init(const std::vector >& kwargs) { + void Init(const std::vector >& kwargs) override { param_.Init(kwargs); } - virtual bool InferShape(std::vector *in_shape, - std::vector *out_shape) const { + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]"; const TShape &dshape = in_shape->at(0); @@ -103,34 +106,34 @@ class SoftmaxProp : public OperatorProperty { return true; } - virtual OperatorProperty* Copy() const { + OperatorProperty* Copy() const override { auto ptr = new SoftmaxProp(); ptr->param_ = param_; return ptr; } - virtual std::string TypeString() const { + std::string TypeString() const override { return "Softmax"; } - virtual std::vector DeclareBackwardDependency( + std::vector DeclareBackwardDependency( const std::vector &out_grad, const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {in_data[kLabel], out_data[kOut]}; } - virtual std::vector > BackwardInplaceOption( + std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { + const std::vector &in_grad) const override { return {{out_data[kOut], in_grad[kData]}}; } - virtual std::vector > ForwardInplaceOption( + std::vector > ForwardInplaceOption( const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const override { return {{in_data[kData], out_data[kOut]}}; } diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc old mode 100644 new mode 100755 index 68de552e7f21..aeff3427d8f3 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -46,14 +46,15 @@ class GraphExecutor::BackwardOpWrapper : public 
Operator { virtual void Forward(const OpContext &ctx, const std::vector &in_data, const std::vector &req, - const std::vector &out_data) { + const std::vector &out_data, + const std::vector &aux_states) { // set things correctly CHECK(arg_data_ptr_.size() == in_data.size()); for (size_t i = 0; i < in_data.size(); ++i) { *(arg_data_ptr_[i]) = in_data[i]; } // redirect internally - op_->Backward(ctx, out_grad_, in_data_, out_data_, req, out_data); + op_->Backward(ctx, out_grad_, in_data_, out_data_, req, out_data, aux_states); } private: @@ -170,18 +171,25 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { OpNode& op_node = op_nodes_[nid]; Operator *op = op_node.op.get(); std::vector req; - std::vector in_data, out_data; + std::vector in_data, out_data, aux_states; in_data.reserve(graph_.nodes[nid].inputs.size()); out_data.reserve(op_node.outputs.size()); req.reserve(op_node.outputs.size()); + aux_states.reserve(op_node.aux_states.size()); OpExecEntry exec; + // output for (const DataEntryInfo& out : op_node.outputs) { out_data.push_back(out.data.data()); exec.mutate_vars.push_back(out.data.var()); req.push_back(out.op_req); } - + // aux + for (const DataEntryInfo& aux : op_node.aux_states) { + aux_states.push_back(aux.data.data()); + exec.mutate_vars.push_back(aux.data.var()); + } + // input for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { const DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; in_data.push_back(info.data.data()); @@ -192,9 +200,9 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { } OpContext* op_ctx_ptr = &op_node.op_ctx; - exec.exec_fun = [op, op_ctx_ptr, in_data, req, out_data] (RunContext ctx) { + exec.exec_fun = [op, op_ctx_ptr, in_data, req, out_data, aux_states] (RunContext ctx) { op_ctx_ptr->run_ctx = ctx; - op->Forward(*op_ctx_ptr, in_data, req, out_data); + op->Forward(*op_ctx_ptr, in_data, req, out_data, aux_states); }; return exec; } @@ -228,7 +236,8 @@ void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) { void GraphExecutor::InitDataEntryInfo(const std::vector &in_args, const std::vector &arg_grad_store, - const std::vector &grad_req_type) { + const std::vector &grad_req_type, + const std::vector &aux_states) { CHECK_EQ(arg_grad_store.size(), grad_req_type.size()); CHECK_EQ(in_args.size(), graph_.arg_nodes.size()); // bind inputs @@ -280,19 +289,37 @@ void GraphExecutor::InitDataEntryInfo(const std::vector &in_args, // shape inference std::vector > out_shapes(op_nodes_.size()); + std::vector > aux_shapes(op_nodes_.size()); for (size_t i = 0; i < out_shapes.size(); ++i) { out_shapes[i].resize(op_nodes_[i].outputs.size()); } for (size_t i = 0; i < graph_.arg_nodes.size(); ++i) { out_shapes[graph_.arg_nodes[i]][0] = in_args[i].shape(); } - CHECK(graph_.InferNodeShapes(topo_order_, &out_shapes)) + CHECK(graph_.InferNodeShapes(topo_order_, &out_shapes, &aux_shapes)) << "Shape inference cannot be complete in bind"; for (size_t i = 0; i < out_shapes.size(); ++i) { for (size_t j = 0; j < out_shapes[i].size(); ++j) { op_nodes_[i].outputs[j].shape = out_shapes[i][j]; } } + // bind aux args + size_t aux_narray_idx = 0; + for (size_t i = 0; i < aux_shapes.size(); ++i) { + op_nodes_[i].aux_states.resize(aux_shapes[i].size()); + for (size_t j = 0; j < aux_shapes[i].size(); ++j) { + DataEntryInfo &info = op_nodes_[i].aux_states[j]; + info.shape = aux_shapes[i][j]; + info.type = kBindByExternal; + CHECK_GT(aux_states.size(), aux_narray_idx) + << "Input auxiliary NArray is less than required"; + info.data = 
diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h
old mode 100644
new mode 100755
index f74d73ec8e44..66cd074b406b
--- a/src/symbol/graph_executor.h
+++ b/src/symbol/graph_executor.h
@@ -20,7 +20,7 @@ namespace mxnet {
 class GraphExecutor : public Executor {
  public:
   virtual ~GraphExecutor() {}
-  virtual void Forward();
+  virtual void Forward(bool is_train);
   virtual void Backward(const std::vector<NArray> &head_grads);
   virtual const std::vector<NArray> &heads() const {
     return heads_narray_;
@@ -30,14 +30,15 @@ class GraphExecutor : public Executor {
             Context ctx,
             const std::vector<NArray> &in_args,
             const std::vector<NArray> &arg_grad_store,
-            const std::vector<OpReqType> &grad_req_type) {
+            const std::vector<OpReqType> &grad_req_type,
+            const std::vector<NArray> &aux_states) {
     CHECK_EQ(grad_req_type.size(), arg_grad_store.size());
     bool need_backward = false;
     for (auto req : grad_req_type) {
       if (req != kNullOp) need_backward = true;
     }
     this->InitGraph(symbol, ctx, need_backward);
-    this->InitDataEntryInfo(in_args, arg_grad_store, grad_req_type);
+    this->InitDataEntryInfo(in_args, arg_grad_store, grad_req_type, aux_states);
     this->InitDataEntryMemory();
     this->InitOpNodes();
     // TODO(bing): remove me when things are OK
@@ -106,6 +107,8 @@ class GraphExecutor : public Executor {
     Context ctx;
     // data entry information about outputs of op
     std::vector<DataEntryInfo> outputs;
+    // auxiliary data information of op
+    std::vector<DataEntryInfo> aux_states;
     // The following parts are constructed in InitOpNodes
     // the real operator
     std::shared_ptr<Operator> op;
@@ -158,13 +161,14 @@ class GraphExecutor : public Executor {
   // initialize internal DataEntryInfo, reference counting
   void InitDataEntryInfo(const std::vector<NArray> &in_args,
                         const std::vector<NArray> &arg_grad_store,
-                        const std::vector<OpReqType> &grad_req_type);
+                        const std::vector<OpReqType> &grad_req_type,
+                        const std::vector<NArray> &aux_states);
   // initialize internal data entries NArray
   void InitDataEntryMemory();
   // initialize OpNode data structure
   void InitOpNodes();
   // run ops from topo order start to end
-  void RunOps(size_t topo_start, size_t topo_end);
+  void RunOps(bool is_train, size_t topo_start, size_t topo_end);
   // get debug string
   std::string DebugStr() const;
   // internal computational graph
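With Forward(bool is_train) on the executor, training and evaluation passes are now distinguished explicitly, while Backward always runs its ops in training mode (RunOps(true, ...)). A rough train/eval skeleton continuing the bind sketch above; the executor.backward([head_grad]) entry point and the head-gradient allocation are assumptions here, not part of this hunk:

    # continuing from the bind sketch above
    out_narray  = executor.heads()[0]
    grad_narray = mx.narray.create(out_shapes[0])   # assumed storage for dLoss/dOut

    # training pass: every operator sees is_train=True in its OpContext
    executor.forward(is_train=True)
    grad_narray.numpy[:] = out_narray.numpy         # stand-in for the real head gradient
    executor.backward([grad_narray])                # assumed signature; backward ops always run with is_train=True

    # evaluation pass: operators see is_train=False
    executor.forward(is_train=False)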
diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc
old mode 100644
new mode 100755
index 53213e2aced7..05c8785de0c8
--- a/src/symbol/static_graph.cc
+++ b/src/symbol/static_graph.cc
@@ -50,7 +50,8 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
 }

 bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
-                                  std::vector<std::vector<TShape> > *node_out_shapes) const {
+                                  std::vector<std::vector<TShape> > *node_out_shapes,
+                                  std::vector<std::vector<TShape> > *node_aux_shapes) const {
   for (uint32_t nid : topo_order) {
     const Node& node = nodes[nid];
     if (node.is_forward()) {
@@ -59,7 +60,9 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
         in_shape.push_back((*node_out_shapes)[e.source_id][e.index]);
       }
       try {
-        if (!node.op->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false;
+        if (!node.op->InferShape(&in_shape,
+                                 &(*node_out_shapes)[nid],
+                                 &(*node_aux_shapes)[nid])) return false;
       } catch (const op::InferShapeError &err) {
         // error handling
         const std::string &op_name = node.name;
@@ -123,8 +126,10 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
 }

 bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
-                             std::vector<TShape> *out_shape) const {
+                             std::vector<TShape> *out_shape,
+                             std::vector<TShape> *aux_shape) const {
   std::vector<std::vector<TShape> > node_out_shapes(nodes.size());
+  std::vector<std::vector<TShape> > node_aux_shapes(nodes.size());
   for (size_t i = 0; i < nodes.size(); ++i) {
     int nout = 1;
     if (nodes[i].is_forward()) {
@@ -140,7 +145,8 @@ bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
     node_out_shapes[arg_nodes[i]][0] = (*in_shape)[i];
   }
   if (!InferNodeShapes(this->TopoSort(),
-                       &node_out_shapes)) return false;
+                       &node_out_shapes,
+                       &node_aux_shapes)) return false;
   for (size_t i = 0; i < arg_nodes.size(); ++i) {
     (*in_shape)[i] = node_out_shapes[arg_nodes[i]][0];
   }
@@ -149,6 +155,13 @@ bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
     const DataEntry &e = heads[i];
     (*out_shape)[i] = node_out_shapes[e.source_id][e.index];
   }
+  for (size_t i = 0; i < node_aux_shapes.size(); ++i) {
+    if (node_aux_shapes[i].size() > 0) {
+      for (auto const &shape : node_aux_shapes[i]) {
+        aux_shape->push_back(shape);
+      }
+    }
+  }
   return true;
 }
diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc
old mode 100644
new mode 100755
index 26ecb9691ede..ddc7d96556e6
--- a/src/symbol/symbol.cc
+++ b/src/symbol/symbol.cc
@@ -211,6 +211,27 @@ std::vector<std::string> Symbol::ListReturns() const {
   return ret;
 }

+std::vector<std::string> Symbol::ListAuxiliaryStates() const {
+  // TODO(linmin, bing): better solution
+  std::vector<std::string> ret;
+  StaticGraph g;
+  this->ToStaticGraph(&g);
+  std::vector<uint32_t> topo_order = g.TopoSort();
+  for (uint32_t nid : topo_order) {
+    const auto& node = g.nodes[nid];
+    if (node.op != nullptr) {
+      auto aux_args = node.op->ListAuxiliaryStates();
+      if (aux_args.size() > 0) {
+        auto &hname = node.name;
+        for (auto const &aux : aux_args) {
+          ret.push_back(hname + '_' + aux);
+        }
+      }
+    }
+  }
+  return ret;
+}
+
 Symbol Symbol::operator[] (size_t index) const {
   size_t nreturn = NumReturns();
   CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index";
@@ -425,15 +446,17 @@ Symbol Symbol::Grad(const std::vector<std::string>& wrt) const {
 }

 bool Symbol::InferShape(std::vector<TShape> *arg_shapes,
-                        std::vector<TShape> *out_shapes) const {
+                        std::vector<TShape> *out_shapes,
+                        std::vector<TShape> *aux_shapes) const {
   StaticGraph g;
   this->ToStaticGraph(&g);
-  return g.InferShape(arg_shapes, out_shapes);
+  return g.InferShape(arg_shapes, out_shapes, aux_shapes);
 }

 bool Symbol::InferShape(const std::unordered_map<std::string, TShape>& known_arg_shapes,
                         std::vector<TShape> *arg_shapes,
-                        std::vector<TShape> *out_shapes) const {
+                        std::vector<TShape> *out_shapes,
+                        std::vector<TShape> *aux_shapes) const {
   StaticGraph g;
   this->ToStaticGraph(&g);
   arg_shapes->clear();
@@ -453,7 +476,7 @@ bool Symbol::InferShape(const std::unordered_map<std::string, TShape>& known_arg
         [](decltype(*known_arg_shapes.begin())& kv)->std::string { return kv.first; });
     KeywordArgumentMismatch("Symbol.InterShape", keys, ListArguments());
   }
-  return g.InferShape(arg_shapes, out_shapes);
+  return g.InferShape(arg_shapes, out_shapes, aux_shapes);
 }

 Symbol Symbol::Create(OperatorProperty *op) {
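ListAuxiliaryStates above fixes the naming convention for auxiliary states: the owning node's name, an underscore, then the state name reported by the operator property, collected in topological order, which is the same order in which StaticGraph::InferShape appends aux shapes. Assuming the Python wrapper exposes the new MXSymbolListAuxiliaryStates call as Symbol.list_auxiliary_states() (that wrapper change is not shown in these hunks), names and shapes can be lined up like this; the exact suffixes depend on what BatchNormProp registers:

    import mxnet as mx

    data = mx.symbol.Variable('data')
    net  = mx.symbol.BatchNorm(data=data, name='bn1')
    net  = mx.symbol.BatchNorm(data=net, name='bn2')

    aux_names = net.list_auxiliary_states()   # e.g. ['bn1_<state>', ..., 'bn2_<state>', ...]
    arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(100, 3, 28, 28))

    # Both lists follow node topological order, so they can be zipped into a lookup
    # table, mirroring inputs = dict(zip(args_list, arg_narrays)) in test_conv.py.
    aux_shape_dict = dict(zip(aux_names, aux_shapes))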
diff --git a/tests/python/models.py b/tests/python/models.py
index d7fb74e4fd1e..0709b8d49f8f 100644
--- a/tests/python/models.py
+++ b/tests/python/models.py
@@ -3,8 +3,8 @@
 def mlp2():
     data = mx.symbol.Variable('data')
-    out = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000)
+    out = mx.symbol.FullyConnected(data=data, name='fc1', nb_hidden=1000)
     out = mx.symbol.Activation(data=out, act_type='relu')
-    out = mx.symbol.FullyConnected(data=out, name='fc2', num_hidden=10)
+    out = mx.symbol.FullyConnected(data=out, name='fc2', nb_hidden=10)
     return out
diff --git a/tests/python/test_conv.py b/tests/python/test_conv.py
index 0604476d4bb5..d488fd9d2ca5 100644
--- a/tests/python/test_conv.py
+++ b/tests/python/test_conv.py
@@ -12,39 +12,47 @@ def CalAcc(out, label):
 # symbol net
 batch_size = 100
 data = mx.symbol.Variable('data')
-conv1= mx.symbol.Convolution(data = data, name='conv1', nb_filter=32, kernel=(3,3), stride=(1,1), nstep=10)
-act1 = mx.symbol.Activation(data = conv1, name='relu1', act_type="relu")
+conv1= mx.symbol.Convolution(data = data, name='conv1', nb_filter=32, kernel=(3,3), stride=(2,2), nstep=100)
+bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1")
+act1 = mx.symbol.Activation(data = bn1, name='relu1', act_type="relu")
 mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max')

-conv2= mx.symbol.Convolution(data = mp1, name='conv2', nb_filter=32, kernel=(3,3), stride=(1,1), nstep=10)
-act2 = mx.symbol.Activation(data = conv2, name='relu2', act_type="relu")
+conv2= mx.symbol.Convolution(data = mp1, name='conv2', nb_filter=32, kernel=(3,3), stride=(2,2), nstep=100)
+bn2 = mx.symbol.BatchNorm(data = conv2, name="bn2")
+act2 = mx.symbol.Activation(data = bn2, name='relu2', act_type="relu")
 mp2 = mx.symbol.Pooling(data = act2, name = 'mp2', kernel=(2,2), stride=(2,2), pool_type='max')

 fl = mx.symbol.Flatten(data = mp2, name="flatten")
-fc2 = mx.symbol.FullyConnected(data = fl, name='fc2', num_hidden=10)
+fc2 = mx.symbol.FullyConnected(data = fl, name='fc2', nb_hidden=10)
 softmax = mx.symbol.Softmax(data = fc2, name = 'sm')

 args_list = softmax.list_arguments()

 # infer shape
 #data_shape = (batch_size, 784)
 data_shape = (batch_size, 1, 28, 28)
-arg_shapes, out_shapes = softmax.infer_shape(data=data_shape)
+arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape)
 arg_narrays = [mx.narray.create(shape) for shape in arg_shapes]
 grad_narrays = [mx.narray.create(shape) for shape in arg_shapes]
+aux_narrays = [mx.narray.create(shape) for shape in aux_shapes]
+
 inputs = dict(zip(args_list, arg_narrays))
 np.random.seed(0)
 # set random weight
 for name, narray in inputs.items():
     if "weight" in name:
-        narray.numpy[:, :] = np.random.uniform(-0.07, 0.07, narray.numpy.shape)
+        narray.numpy[:] = np.random.uniform(-0.07, 0.07, narray.numpy.shape)
     if "bias" in name:
         narray.numpy[:] = 0.0
+    if "gamma" in name:
+        narray.numpy[:] = 1.0
+    if "beta" in name:
+        narray.numpy[:] = 0.0

 req = ['write_to' for i in range(len(arg_narrays))]
 # bind executer
 # TODO(bing): think of a better bind interface
-executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays, req)
+executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays, req, aux_narrays)
 # update

 out_narray = executor.heads()[0]
@@ -87,7 +95,7 @@ def test_mnist():
             label = label.numpy.flatten()
             inputs["data"].numpy[:] = data
             inputs["sm_label"].numpy[:] = label
-            executor.forward()
+            executor.forward(is_train = True)
             train_acc += CalAcc(out_narray.numpy, label)
             train_nbatch += 1
             grad_narray.numpy[:] = out_narray.numpy
@@ -101,7 +109,7 @@ def test_mnist():
             data = data.numpy
             label = label.numpy.flatten()
             inputs["data"].numpy[:] = data
-            executor.forward()
+            executor.forward(is_train = False)
             val_acc += CalAcc(out_narray.numpy, label)
             val_nbatch += 1
         print("Train Acc: ", train_acc / train_nbatch)
@@ -113,3 +121,7 @@ def test_mnist():

     assert(acc_train > 0.84)
     assert(acc_val > 0.96)
+
+if __name__ == "__main__":
+    test_mnist()
+
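The by-name initialisation loop in test_conv.py now covers four parameter kinds (weight, bias, and BatchNorm's gamma/beta). If more kinds accumulate, one option is to table-drive the same logic; this is purely a readability sketch over structures the test already builds (inputs and narray.numpy), not part of the patch:

    import numpy as np

    init_rules = [
        ('weight', lambda shape: np.random.uniform(-0.07, 0.07, shape)),
        ('bias',   np.zeros),
        ('gamma',  np.ones),
        ('beta',   np.zeros),
    ]
    for name, narray in inputs.items():
        for key, init in init_rules:
            if key in name:
                # each rule produces a full array of the parameter's shape
                narray.numpy[:] = init(narray.numpy.shape)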
diff --git a/tests/python/test_inter_shape.py b/tests/python/test_infer_shape.py
similarity index 72%
rename from tests/python/test_inter_shape.py
rename to tests/python/test_infer_shape.py
index fa18ff175fbf..b7f1efd75225 100644
--- a/tests/python/test_inter_shape.py
+++ b/tests/python/test_infer_shape.py
@@ -8,9 +8,9 @@ def test_mlp2_infer_shape():
     out = models.mlp2()
     # infer shape
     data_shape = (100, 100)
-    arg_shapes, out_shapes = out.infer_shape(data=data_shape)
+    arg_shapes, out_shapes, aux_shapes = out.infer_shape(data=data_shape)
     arg_shape_dict = dict(zip(out.list_arguments(), arg_shapes))
-
+    print(len(aux_shapes))
     assert len(out_shapes) == 1
     assert out_shapes[0] == (100, 10)
     true_shapes = {'fc2_bias': (10,),
@@ -26,5 +26,8 @@ def test_mlp2_infer_error():
     out = models.mlp2()
     weight_shape= (1, 100)
     data_shape = (100, 100)
-    arg_shapes, out_shapes = out.infer_shape(data=data_shape, fc1_weight=weight_shape)
+    arg_shapes, out_shapes, aux_shapes = out.infer_shape(data=data_shape, fc1_weight=weight_shape)
+
+if __name__ == "__main__":
+    test_mlp2_infer_shape()
+    test_mlp2_infer_error()
diff --git a/tests/python/test_mlp.py b/tests/python/test_mlp.py
index 8a84d50536c3..174d5be5ce63 100644
--- a/tests/python/test_mlp.py
+++ b/tests/python/test_mlp.py
@@ -13,16 +13,16 @@ def CalAcc(out, label):
 # symbol net
 batch_size = 100
 data = mx.symbol.Variable('data')
-fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
+fc1 = mx.symbol.FullyConnected(data = data, name='fc1', nb_hidden=128)
 act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
+fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', nb_hidden = 64)
 act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
+fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', nb_hidden=10)
 softmax = mx.symbol.Softmax(data = fc3, name = 'sm')

 args_list = softmax.list_arguments()

 # infer shape
 data_shape = (batch_size, 784)
-arg_shapes, out_shapes = softmax.infer_shape(data=data_shape)
+arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape)
 arg_narrays = [mx.narray.create(shape) for shape in arg_shapes]
 grad_narrays = [mx.narray.create(shape) for shape in arg_shapes]
 inputs = dict(zip(args_list, arg_narrays))
@@ -104,3 +104,5 @@ def test_mlp():

     assert(acc_train > 0.98)
     assert(acc_val > 0.97)
+if __name__ == "__main__":
+    test_mlp()
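For networks without BatchNorm-style operators, as in test_mlp.py and test_infer_shape.py above, the only migration needed is unpacking the third return value; the auxiliary list is simply empty. A small check in the style of those tests, using models.mlp2 as defined above:

    import mxnet as mx
    import models

    out = models.mlp2()
    arg_shapes, out_shapes, aux_shapes = out.infer_shape(data=(100, 100))
    # no operator in mlp2 registers auxiliary states
    assert len(aux_shapes) == 0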
diff --git a/tests/python/test_symbol.py b/tests/python/test_symbol.py
index b08f6a310570..2682d80957fc 100644
--- a/tests/python/test_symbol.py
+++ b/tests/python/test_symbol.py
@@ -11,15 +11,15 @@ def test_symbol_basic():

 def test_compose():
     data = mx.symbol.Variable('data')
-    net1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
-    net1 = mx.symbol.FullyConnected(data=net1, name='fc2', num_hidden=100)
+    net1 = mx.symbol.FullyConnected(data=data, name='fc1', nb_hidden=10)
+    net1 = mx.symbol.FullyConnected(data=net1, name='fc2', nb_hidden=100)
     net1.list_arguments() == ['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias']

-    net2 = mx.symbol.FullyConnected(name='fc3', num_hidden=10)
+    net2 = mx.symbol.FullyConnected(name='fc3', nb_hidden=10)
     net2 = mx.symbol.Activation(data=net2)
-    net2 = mx.symbol.FullyConnected(data=net2, name='fc4', num_hidden=20)
+    net2 = mx.symbol.FullyConnected(data=net2, name='fc4', nb_hidden=20)
     print(net2.debug_str())
     composed = net2(fc3_data=net1, name='composed')