diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index 10a7c40eea03..1007234a7974 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -5,7 +5,7 @@ import sys sys.path.append("../../tests/python") import get_data - +import time """ CXXNET Result: @@ -70,8 +70,8 @@ def ConvFactory(**kwargs): param = copy.copy(kwargs) act = param["act_type"] del param["act_type"] + param["workspace"] = 512 param["name"] = "conv%d" % conv_cnt - param["nstep"] = 64 conv = mx.symbol.Convolution(**param) bn = mx.symbol.BatchNorm(data = conv, name="bn%d" % conv_cnt) relu = mx.symbol.Activation(data = bn, name = "%s%d" % (act, conv_cnt), act_type=act) @@ -89,13 +89,11 @@ def DownsampleFactory(data, ch_3x3, stride = 2): param["num_filter"] = ch_3x3 param["act_type"] = "relu" param["data"] = data - param["nstep"] = 100 param["pad"] = (1, 1) conv3x3 = ConvFactory(**param) # pool del param["num_filter"] del param["act_type"] - del param["nstep"] del param["pad"] param["pool_type"] = "max" param["name"] = "pool%d" % pool_cnt @@ -117,7 +115,6 @@ def SimpleFactory(data, ch_1x1, ch_3x3): param["stride"] = (1, 1) param["act_type"] = "relu" param["data"] = data - param["nstep"] = 128 conv1x1 = ConvFactory(**param) # 3x3 @@ -143,7 +140,7 @@ def RandomInit(narray): in3a = SimpleFactory(conv1, 32, 32) in3b = SimpleFactory(in3a, 32, 48) in3c = DownsampleFactory(in3b, 80) -in4a = SimpleFactory(in3c, 112, 38) +in4a = SimpleFactory(in3c, 112, 48) in4b = SimpleFactory(in4a, 96, 64) in4c = SimpleFactory(in4b, 80, 80) in4d = SimpleFactory(in4c, 48, 96) @@ -155,27 +152,30 @@ def RandomInit(narray): fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1") loss = mx.symbol.Softmax(data=fc, name="sm") -args_list = loss.list_arguments() +epoch = 9 +lr = 0.05 +wd = 0.0001 +momentum = 0.9 batch_size = 128 data_shape = (batch_size, 3, 28, 28) -arg_shapes, out_shapes, aux_shapes = loss.infer_shape(data=data_shape) -arg_narrays = [mx.narray.zeros(shape, 
ctx=mx.Context("gpu")) for shape in arg_shapes] -grad_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] -mom_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes] -aux_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in aux_shapes] +in_data = mx.narray.empty(data_shape, mx.Context('gpu')) +executor, executor_data = loss.simple_bind(mx.Context('gpu'), {"data": in_data}) +out_narray = executor.heads()[0] +pred = mx.narray.zeros(out_narray.shape) -inputs = dict(zip(args_list, arg_narrays)) +inputs = dict(zip(loss.list_arguments(), executor_data["args"])) +block = list(zip(executor_data["grad"], + executor_data["args"], + executor_data["momentum"])) -name2shape = dict(zip(args_list, arg_shapes)) -pred = mx.narray.zeros(out_shapes[0]) np.random.seed(0) # set random weight -for name, narray in inputs.items(): +for name, narray in zip(loss.list_arguments(), executor_data["args"]): if "weight" in name: narray[:] = np.random.uniform(-0.1, 0.1, narray.shape) if "bias" in name: @@ -185,25 +185,11 @@ def RandomInit(narray): if "beta" in name: narray[:] = 0.0 -# bind executer -# TODO(bing): think of a better bind interface -executor = loss.bind(mx.Context('gpu'), arg_narrays, grad_narrays, 'write', aux_narrays) -# update - -out_narray = executor.heads()[0] - -epoch = 9 -lr = 0.05 -wd = 0.0001 -momentum = 0.9 - def Update(grad, weight, mom): mom[:] *= momentum mom[:] += -lr * (grad / batch_size + wd * weight) weight[:] += mom -block = list(zip(grad_narrays, arg_narrays, mom_narrays)) - #check data get_data.GetCifar10() @@ -224,15 +210,17 @@ def Update(grad, weight, mom): batch_size=batch_size, nthread=1) -tmp_label = mx.narray.zeros(name2shape["sm_label"]) +tmp_label = mx.narray.zeros(inputs["sm_label"].shape) -def progress(count, total, suffix=''): - bar_len = 80 +def progress(count, total, epoch, toc): + bar_len = 60 filled_len = int(round(bar_len * count / float(total))) percents = round(100.0 * 
count / float(total), 1) bar = '=' * filled_len + '-' * (bar_len - filled_len) - + tic = time.time() + speed = batch_size / float(tic - toc) + suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed) sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix)) def test_cifar(): @@ -247,7 +235,7 @@ def test_cifar(): val_nbatch = 0 all_train_bacth = 50000 / float(batch_size) for data, label in train_dataiter: - progress(train_nbatch, all_train_bacth, "Epoch %d" % i) + toc = time.time() label = label.asnumpy().flatten() tmp_label[:] = label inputs["data"][:] = data @@ -260,6 +248,7 @@ def test_cifar(): for grad, weight, mom in block: Update(grad, weight, mom) + progress(train_nbatch, all_train_bacth, i, toc) # evaluate for data, label in test_dataiter: diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 64c1515a6f3b..cc02293bb20e 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -8,6 +8,7 @@ #define MXNET_OPERATOR_H_ #include +#include #include #include #include @@ -385,6 +386,9 @@ class OperatorProperty { * \return a new constructed OperatorProperty */ static OperatorProperty *Create(const char* type_name); + + virtual void Save(dmlc::JSONWriter *writer) const = 0; + virtual void Load(dmlc::JSONReader *reader) = 0; }; /*! \brief typedef the factory function of operator property */ diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index 97eed74a53be..c0eb0e1b0142 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -8,6 +8,8 @@ #define MXNET_SYMBOLIC_H_ #include +#include +#include #include #include #include @@ -64,6 +66,11 @@ class StaticGraph { if (source_id == other.source_id) return index < other.index; return source_id < other.source_id; } + + /*! \brief interface for json serialization */ + void Save(dmlc::JSONWriter *writer) const; + /*! \brief interface for json serialization */ + void Load(dmlc::JSONReader *reader); }; /*! * \brief Operation Node in static graphs. 
@@ -95,6 +102,21 @@ class StaticGraph { int32_t backward_source_id; /*! \brief default constructor */ Node() : backward_source_id(-1) {} + + friend void swap(Node& lhs, Node& rhs) { + std::swap(lhs.op, rhs.op); + std::swap(lhs.name, rhs.name); + std::swap(lhs.inputs, rhs.inputs); + } + /*! \brief copy constructor in favor of serialization. */ + Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr), + name(another.name), + inputs(another.inputs) {} + + inline Node& operator=(Node another) { + swap(*this, another); + return *this; + } /*! \return whether the node is forward op node */ inline bool is_forward() const { return op != nullptr; @@ -107,6 +129,10 @@ class StaticGraph { inline bool is_variable() const { return op == nullptr && !is_backward(); } + /*! \brief interface for json serialization */ + void Save(dmlc::JSONWriter *writer) const; + /*! \brief interface for json serialization */ + void Load(dmlc::JSONReader *reader); }; /*! \brief all nodes in the graph */ std::vector nodes; @@ -114,6 +140,14 @@ class StaticGraph { std::vector arg_nodes; /*! \brief heads outputs of the graph */ std::vector heads; + /*! \brief load static graph from json. TODO: a static creator's better */ + void Load(const std::string& json); + /*! \brief save static graph to json */ + void Save(std::string* json) const; + /*! \brief interface for json serialization */ + void Save(dmlc::JSONWriter *writer) const; + /*! \brief interface for json serialization */ + void Load(dmlc::JSONReader *reader); // funtions to help inference in static graph /*! 
* \brief Perform a topological sort on the graph diff --git a/lib/README.md b/lib/README.md deleted file mode 100644 index 24d68ff1acba..000000000000 --- a/lib/README.md +++ /dev/null @@ -1 +0,0 @@ -MXNet library diff --git a/python/mxnet/narray.py b/python/mxnet/narray.py index acc05d08d546..208fd8e17d7a 100644 --- a/python/mxnet/narray.py +++ b/python/mxnet/narray.py @@ -349,9 +349,7 @@ def zeros(shape, ctx=None): out: Array The created NArray. """ - if ctx is None: - ctx = Context.default_ctx - arr = NArray(handle=_new_alloc_handle(shape, ctx, False)) + arr = empty(shape, ctx) arr[:] = 0.0 return arr @@ -371,15 +369,11 @@ def ones(shape, ctx=None): out: Array The created NArray. """ - if ctx is None: - ctx = Context.default_ctx - arr = NArray(handle=_new_alloc_handle(shape, ctx, False)) + arr = empty(shape, ctx) arr[:] = 1.0 return arr - - def array(source_array, ctx=None): """Create a new NArray that copies content from source_array. diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index f882933538b2..496a3487c55f 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -10,7 +10,7 @@ from .base import NArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call from .context import Context -from .narray import NArray +from .narray import NArray, zeros from .executor import Executor @@ -332,6 +332,49 @@ def _get_narray_handle(arg_key, args, arg_names, allow_missing): raise TypeError('Only Accept list of NArrays or dict of str->NArray') return c_array(NArrayHandle, arg_handles) + def simple_bind(self, ctx, args, grad_req='write'): + """Simply bind current symbol to get an executor + Parameters + ---------- + ctx : Context + The device context the generated executor to run on. + + args : list of NArray or dict of str->NArray + Input arguments to the symbol. + - type is dict of str->NArray, then it maps the name of arguments + to the corresponding NArray, + - Not all the arguments must be provided. 
+ Returns + ------- + executor : mxnet.Executor + The generated Executor + executor_data : dict of str -> list(NArray) + The data for the executor, + key is "args", "grad", "momentum", "auxiliary_states" + sequence is the same as the list function + """ + if not isinstance(args, dict): + raise TypeError("args must be dict of str->NArray") + input_shapes = dict((arr[0], arr[1].shape) for arr in args.items()) + arg_shapes, out_shapes, aux_shapes = self.infer_shape(**input_shapes) + if arg_shapes == None: + raise ValueError("Input node is not complete") + # alloc space + arg_narrays = [] + for name, shape in zip(self.list_arguments(), arg_shapes): + if name in args: + arg_narrays.append(args[name]) + else: + arg_narrays.append(zeros(shape, ctx)) + # TODO(bing): special treat input data grad + grad_narrays = [zeros(shape, ctx) for shape in arg_shapes] + mom_narrays = [zeros(shape, ctx) for shape in arg_shapes] + aux_narrays = [zeros(shape, ctx) for shape in aux_shapes] + executor = self.bind(ctx, arg_narrays, grad_narrays, grad_req, aux_narrays) + executor_data = {"args" : arg_narrays, "grad" : grad_narrays, + "momentum" : mom_narrays, "auxiliary_states" : aux_narrays} + return (executor, executor_data) + def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): """Bind current symbol to get an executor. 
diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h index 43aa4f01637a..37de320d3227 100644 --- a/src/operator/activation-inl.h +++ b/src/operator/activation-inl.h @@ -35,6 +35,17 @@ struct ActivationParam : public dmlc::Parameter { .add_enum("tanh", kTanh) .describe("Activation function to be applied."); } + + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("act_type", act_type); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("act_type", &act_type); + helper.ReadAllFields(reader); + } }; /** @@ -84,7 +95,7 @@ template Operator* CreateOp(ActivationParam type); #if DMLC_USE_CXX11 -class ActivationProp : public OperatorProperty { +class ActivationProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -135,12 +146,8 @@ class ActivationProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - ActivationParam param_; }; #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_ACTIVATION_INL_H_ - diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h index 0f3b303f85b6..386a039ab222 100644 --- a/src/operator/batch_norm-inl.h +++ b/src/operator/batch_norm-inl.h @@ -32,6 +32,19 @@ struct BatchNormParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(momentum).set_default(0.1f) .describe("Momentum for moving average"); } + + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("eps", eps); + writer->WriteObjectKeyValue("momentum", momentum); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("eps", &eps); + helper.DeclareField("momentum", &momentum); + helper.ReadAllFields(reader); + } }; template @@ -186,7 +199,7 @@ Operator 
*CreateOp(BatchNormParam param); #if DMLC_USE_CXX11 -class BatchNormProp : public OperatorProperty { +class BatchNormProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -261,8 +274,9 @@ class BatchNormProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; - private: - BatchNormParam param_; + std::vector BackwardResource() const override { + return {Resource::kTempSpace}; + } }; // class BatchNormProp #endif // DMLC_USE_CXX11 diff --git a/src/operator/concat-inl.h b/src/operator/concat-inl.h index 3eaf47845292..809487ed41b9 100644 --- a/src/operator/concat-inl.h +++ b/src/operator/concat-inl.h @@ -28,6 +28,17 @@ struct ConcatParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(num_args).set_range(1, 6) .describe("Number of inputs to be concated."); } + + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("num_args", num_args); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("num_args", &num_args); + helper.ReadAllFields(reader); + } }; // struct ConcatParam template @@ -163,7 +174,7 @@ template Operator *CreateOp(ConcatParam param); #if DMLC_USE_CXX11 -class ConcatProp : public OperatorProperty { +class ConcatProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -223,9 +234,6 @@ class ConcatProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - ConcatParam param_; }; // class ConcatProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h index b313aab14f94..7eca474e2c4f 100644 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -30,7 +30,7 @@ struct ConvolutionParam : public dmlc::Parameter { TShape pad; uint32_t num_filter; uint32_t num_group; 
- uint32_t nstep; + uint32_t workspace; bool no_bias; DMLC_DECLARE_PARAMETER(ConvolutionParam) { int shape[] = {1, 1}; @@ -44,11 +44,47 @@ struct ConvolutionParam : public dmlc::Parameter { .describe("convolution filter(channel) number"); DMLC_DECLARE_FIELD(num_group).set_default(1) .describe("number of groups partition"); - DMLC_DECLARE_FIELD(nstep).set_default(2).set_range(1, 10000) - .describe("process n images once"); + DMLC_DECLARE_FIELD(workspace).set_default(128).set_range(1, 10000) + .describe("Tmp workspace for convolution (MB)"); DMLC_DECLARE_FIELD(no_bias).set_default(false) .describe("Whether to disable bias parameter."); } + + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + std::string str; + TShape2String(kernel, &str); + writer->WriteObjectKeyValue("kernel", str); + TShape2String(stride, &str); + writer->WriteObjectKeyValue("stride", str); + TShape2String(pad, &str); + writer->WriteObjectKeyValue("pad", str); + writer->WriteObjectKeyValue("num_filter", num_filter); + writer->WriteObjectKeyValue("num_group", num_group); + writer->WriteObjectKeyValue("workspace", workspace); + std::string no_bias_str = no_bias ? 
"true" : "false"; + writer->WriteObjectKeyValue("no_bias", no_bias_str); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + std::string kernel_str; + helper.DeclareField("kernel", &kernel_str); + std::string stride_str; + helper.DeclareField("stride", &stride_str); + std::string pad_str; + helper.DeclareField("pad", &pad_str); + helper.DeclareField("num_filter", &num_filter); + helper.DeclareField("num_group", &num_group); + helper.DeclareField("workspace", &workspace); + std::string no_bias_str; + helper.DeclareField("no_bias", &no_bias_str); + helper.ReadAllFields(reader); + no_bias = no_bias_str == "true"; + String2TShape(kernel_str, &kernel); + String2TShape(stride_str, &stride); + String2TShape(pad_str, &pad); + } }; template @@ -80,8 +116,8 @@ class ConvolutionOp : public Operator { Tensor out = out_data[kOut].get(s); this->InitTemp(ctx, data.shape_, out.shape_); const index_t nbatch = data.size(0); - for (index_t i = 0; i < nbatch; i += param_.nstep) { - const index_t step = std::min(param_.nstep, nbatch - i); + for (index_t i = 0; i < nbatch; i += nstep_) { + const index_t step = std::min(nstep_, nbatch - i); temp_col_.Resize(mshadow::Shape2(shape_colunit_[0], shape_colunit_[1] * step)); temp_dst_.Resize(mshadow::Shape3(shape_dstunit_[0], @@ -148,8 +184,8 @@ class ConvolutionOp : public Operator { Tensor gwmat = in_grad[kWeight].get_with_shape(wmat_shape, s); this->InitTemp(ctx, data.shape_, grad.shape_); const index_t nbatch = data.size(0); - for (index_t i = 0; i < nbatch; i += param_.nstep) { - const index_t step = std::min(param_.nstep, nbatch - i); + for (index_t i = 0; i < nbatch; i += nstep_) { + const index_t step = std::min(nstep_, nbatch - i); temp_col_.Resize(Shape2(shape_colunit_[0], shape_colunit_[1] * step)); temp_dst_.Resize(Shape3(shape_dstunit_[0], @@ -220,16 +256,19 @@ class ConvolutionOp : public Operator { shape_dstunit_ = mshadow::Shape3(param_.num_group, param_.num_filter / 
param_.num_group, oshape[2] * oshape[3]); - int nop = (ishape[0] + param_.nstep - 1) / param_.nstep; - param_.nstep = (ishape[0] + nop - 1) / nop; + const uint32_t workspace_size = param_.workspace << 18; + nstep_ = std::max(std::min(static_cast(workspace_size / shape_colunit_.Size()), + ishape[0]), 1U); + int nop = (ishape[0] + nstep_ - 1) / nstep_; + nstep_ = (ishape[0] + nop - 1) / nop; mshadow::Stream *s = ctx.get_stream(); temp_col_.set_stream(s); temp_dst_.set_stream(s); temp_col_.Resize(mshadow::Shape2(shape_colunit_[0], - shape_colunit_[1] * param_.nstep)); + shape_colunit_[1] * nstep_)); temp_dst_.Resize(mshadow::Shape3(shape_dstunit_[0], shape_dstunit_[1], - shape_dstunit_[2] * param_.nstep)); + shape_dstunit_[2] * nstep_)); } ConvolutionParam param_; @@ -238,13 +277,14 @@ class ConvolutionOp : public Operator { mshadow::TensorContainer temp_dst_; mshadow::Shape<2> shape_colunit_; mshadow::Shape<3> shape_dstunit_; + index_t nstep_; }; // class ConvolutionOp template Operator* CreateOp(ConvolutionParam param); #if DMLC_USE_CXX11 -class ConvolutionProp : public OperatorProperty { +class ConvolutionProp : public ParamOperatorProperty { public: std::vector ListArguments() const override { if (!param_.no_bias) { @@ -328,8 +368,13 @@ class ConvolutionProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; - private: - ConvolutionParam param_; + std::vector ForwardResource() const override { + return {Resource::kTempSpace}; + } + + std::vector BackwardResource() const override { + return {Resource::kTempSpace}; + } }; // class ConvolutionProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/elementwise_binary_op-inl.h b/src/operator/elementwise_binary_op-inl.h index b3ae8adc3de1..e549cc9966fd 100644 --- a/src/operator/elementwise_binary_op-inl.h +++ b/src/operator/elementwise_binary_op-inl.h @@ -153,7 +153,7 @@ Operator* CreateElementWiseBinaryOp(ElementWiseBinaryOpType type); #if DMLC_USE_CXX11 template -class 
ElementWiseBinaryOpProp : public OperatorProperty { +class ElementWiseBinaryOpProp : public NoParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { CHECK_EQ(kwargs.size(), 0) diff --git a/src/operator/elementwise_sum-inl.h b/src/operator/elementwise_sum-inl.h index c2a890b2e976..06ac83ad948e 100644 --- a/src/operator/elementwise_sum-inl.h +++ b/src/operator/elementwise_sum-inl.h @@ -30,6 +30,17 @@ struct ElementWiseSumParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(num_args).set_range(1, 100) .describe("Number of inputs to be sumed."); } + + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("num_args", num_args); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("num_args", &num_args); + helper.ReadAllFields(reader); + } }; template @@ -102,6 +113,16 @@ class ElementWiseSumOp : public Operator { Assign(igrad, req[i], F(ograd)); } } + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("size_", size_); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("size_", &size_); + helper.ReadAllFields(reader); + } private: int size_; @@ -111,7 +132,7 @@ template Operator* CreateOp(ElementWiseSumParam param); #if DMLC_USE_CXX11 -class ElementWiseSumProp : public OperatorProperty { +class ElementWiseSumProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -179,9 +200,6 @@ class ElementWiseSumProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - ElementWiseSumParam param_; }; // class ElementWiseSumProp #endif // DMLC_USE_CXX11 diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h index 75fe14d3aab8..82b075278a41 100644 --- 
a/src/operator/fully_connected-inl.h +++ b/src/operator/fully_connected-inl.h @@ -35,6 +35,21 @@ struct FullyConnectedParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(no_bias).set_default(false) .describe("Whether to disable bias parameter."); } + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("num_hidden", num_hidden); + std::string no_bias_str = no_bias ? "true" : "false"; + writer->WriteObjectKeyValue("no_bias", no_bias_str); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("num_hidden", &num_hidden); + std::string no_bias_str; + helper.DeclareField("no_bias", &no_bias_str); + helper.ReadAllFields(reader); + no_bias = no_bias_str == "true"; + } }; /** @@ -116,7 +131,7 @@ template Operator* CreateOp(FullyConnectedParam param); #if DMLC_USE_CXX11 -class FullyConnectedProp : public OperatorProperty { +class FullyConnectedProp : public ParamOperatorProperty { public: std::vector ListArguments() const override { if (!param_.no_bias) { @@ -181,9 +196,6 @@ class FullyConnectedProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - FullyConnectedParam param_; }; // class FullyConnectedSymbol #endif } // namespace op diff --git a/src/operator/operator_common.cc b/src/operator/operator_common.cc new file mode 100644 index 000000000000..c43b1c79d75e --- /dev/null +++ b/src/operator/operator_common.cc @@ -0,0 +1,27 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file operator_common.cc + * \brief implementation of common internal functions. + */ + +#include "./operator_common.h" + +namespace mxnet { +namespace op { + +/*! \brief helper static function to read TShape. */ +void String2TShape(const std::string& str, TShape* shape) { + std::istringstream iss(str); + iss >> *shape; +} + +/*! \brief helper static function to write TShape. 
*/ +void TShape2String(const TShape& shape, std::string* str) { + str->clear(); + std::ostringstream oss; + oss << shape; + *str = oss.str(); +} + +} // namespace op +} // namespace mxnet diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index eea731c8fbe6..70ef3aea476f 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -8,9 +8,12 @@ #ifndef MXNET_OPERATOR_OPERATOR_COMMON_H_ #define MXNET_OPERATOR_OPERATOR_COMMON_H_ +#include #include #include #include +#include +#include #include namespace mxnet { @@ -87,6 +90,43 @@ struct InferShapeError { } #endif +#if DMLC_USE_CXX11 +template +class ParamOperatorProperty : public OperatorProperty { + public: + ParamOperatorProperty() {} + explicit ParamOperatorProperty(Param param) : param_(param) {} + inline void Save(dmlc::JSONWriter *writer) const { + param_.Save(writer); + } + inline void Load(dmlc::JSONReader *reader) { + param_.Load(reader); + } + inline bool operator==(const ParamOperatorProperty& other) const { + return param_ == other.param_; + } + protected: + Param param_; +}; + +class NoParamOperatorProperty : public OperatorProperty { + public: + inline void Save(dmlc::JSONWriter *writer) const { + } + inline void Load(dmlc::JSONReader *reader) { + } + inline bool operator==(const NoParamOperatorProperty& other) const { + return true; + } +}; +#endif // DMLC_USE_CXX11 +/*! \brief helper static function to read TShape. */ +void String2TShape(const std::string& str, TShape* shape); + +/*! \brief helper static function to write TShape. 
*/ +void TShape2String(const TShape& shape, std::string* str); + + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/param.h b/src/operator/param.h index f0ce5886e2fb..9b08c197a160 100644 --- a/src/operator/param.h +++ b/src/operator/param.h @@ -71,4 +71,3 @@ struct Param { } // namespace mxnet #endif // MXNET_OPERATOR_PARAM_H_ - diff --git a/src/operator/pooling-inl.h b/src/operator/pooling-inl.h index b0d483ef0217..a40abbbbdb38 100644 --- a/src/operator/pooling-inl.h +++ b/src/operator/pooling-inl.h @@ -51,6 +51,33 @@ struct PoolingParam : public dmlc::Parameter { .set_expect_ndim(2) .describe("pad for pooling: (y, x)"); } + + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + std::string str; + TShape2String(kernel, &str); + writer->WriteObjectKeyValue("kernel", str); + TShape2String(stride, &str); + writer->WriteObjectKeyValue("stride", str); + TShape2String(pad, &str); + writer->WriteObjectKeyValue("pad", str); + writer->WriteObjectKeyValue("pool_type", pool_type); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + std::string kernel_str; + helper.DeclareField("kernel", &kernel_str); + std::string stride_str; + helper.DeclareField("stride", &stride_str); + std::string pad_str; + helper.DeclareField("pad", &pad_str); + helper.DeclareField("pool_type", &pool_type); + helper.ReadAllFields(reader); + String2TShape(kernel_str, &kernel); + String2TShape(stride_str, &stride); + String2TShape(pad_str, &pad); + } }; template @@ -154,7 +181,7 @@ Operator* CreateOp(PoolingParam param); #if DMLC_USE_CXX11 -class PoolingProp : public OperatorProperty { +class PoolingProp : public ParamOperatorProperty { public: void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -205,9 +232,6 @@ class PoolingProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - PoolingParam param_; 
}; // class PoolingProp #endif // DMLC_USE_CXX11 } // namespace op diff --git a/src/operator/reshape-inl.h b/src/operator/reshape-inl.h index 8bd95c49927e..baf2feaa013c 100644 --- a/src/operator/reshape-inl.h +++ b/src/operator/reshape-inl.h @@ -28,6 +28,21 @@ struct ReshapeParam : public dmlc::Parameter { DMLC_DECLARE_PARAMETER(ReshapeParam) { DMLC_DECLARE_FIELD(target_shape).describe("Target new shape"); } + + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + std::string str; + TShape2String(target_shape, &str); + writer->WriteObjectKeyValue("target_shape", str); + writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + std::string target_shape_str; + helper.DeclareField("target_shape", &target_shape_str); + helper.ReadAllFields(reader); + String2TShape(target_shape_str, &target_shape); + } }; template @@ -83,11 +98,11 @@ template Operator* CreateOp(); #if DMLC_USE_CXX11 -class ReshapeProp : public OperatorProperty { +class ReshapeProp : public ParamOperatorProperty { public: ReshapeProp() {} - explicit ReshapeProp(ReshapeParam param) : param_(param) {} + explicit ReshapeProp(ReshapeParam param) : ParamOperatorProperty(param) {} void Init(const std::vector >& kwargs) override { param_.Init(kwargs); @@ -140,9 +155,6 @@ class ReshapeProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - ReshapeParam param_; }; // class ReshapeProp class FlattenProp : public ReshapeProp { diff --git a/src/operator/softmax-inl.h b/src/operator/softmax-inl.h index cf4e2671d719..d6f34c971d49 100644 --- a/src/operator/softmax-inl.h +++ b/src/operator/softmax-inl.h @@ -29,6 +29,16 @@ struct SoftmaxParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(grad_scale).set_default(1.0f) .describe("Scale the gradient by a float factor"); }; + inline void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("grad_scale", grad_scale); + 
writer->EndObject(); + } + inline void Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("grad_scale", &grad_scale); + helper.ReadAllFields(reader); + } }; template @@ -83,7 +93,7 @@ template Operator* CreateOp(SoftmaxParam param); #if DMLC_USE_CXX11 -class SoftmaxProp : public OperatorProperty { +class SoftmaxProp : public ParamOperatorProperty { public: std::vector ListArguments() const override { return {"data", "label"}; @@ -138,9 +148,6 @@ class SoftmaxProp : public OperatorProperty { } Operator* CreateOperator(Context ctx) const; - - private: - SoftmaxParam param_; }; // class SoftmaxProp #endif // DMLC_USE_CXX11 diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index c24ee1a085d5..6ffdc63d4188 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -280,4 +280,91 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, } } } + +void StaticGraph::DataEntry::Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("source_id", source_id); + writer->WriteObjectKeyValue("index", index); + writer->EndObject(); +} + +void StaticGraph::DataEntry::Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("source_id", &source_id); + helper.DeclareField("index", &index); + helper.ReadAllFields(reader); +} + +void StaticGraph::Node::Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + if (op.get() != nullptr) { + writer->WriteObjectKeyValue("op_type", op.get()->TypeString()); + std::ostringstream os; + dmlc::JSONWriter subWriter(&os); + subWriter.BeginObject(); + subWriter.WriteObjectKeyValue("op", *(op.get())); + subWriter.EndObject(); + writer->WriteObjectKeyValue("op", os.str()); + } else { + std::string jsonNull = "null"; + writer->WriteObjectKeyValue("op_type", jsonNull); + writer->WriteObjectKeyValue("op", jsonNull); + } + writer->WriteObjectKeyValue("name", name); + 
writer->WriteObjectKeyValue("inputs", inputs); + writer->WriteObjectKeyValue("backward_source_id", backward_source_id); + writer->EndObject(); +} + +void StaticGraph::Node::Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper firstHelper; + std::string op_type_str; + firstHelper.DeclareField("op_type", &op_type_str); + std::string op_str; + firstHelper.DeclareField("op", &op_str); + firstHelper.DeclareField("name", &name); + firstHelper.DeclareField("inputs", &inputs); + firstHelper.DeclareField("backward_source_id", &backward_source_id); + firstHelper.ReadAllFields(reader); + if (op_type_str != "null") { + dmlc::JSONObjectReadHelper secondHelper; + std::istringstream iss(op_str); + dmlc::JSONReader subReader(&iss); + op.reset(OperatorProperty::Create(op_type_str.c_str())); + secondHelper.DeclareField("op", op.get()); + secondHelper.ReadAllFields(reader); + } else { + op.reset(nullptr); + } +} + +void StaticGraph::Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("nodes", nodes); + writer->WriteObjectKeyValue("arg_nodes", arg_nodes); + writer->WriteObjectKeyValue("heads", heads); + writer->EndObject(); +} + +void StaticGraph::Load(dmlc::JSONReader *reader) { + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("nodes", &nodes); + helper.DeclareField("arg_nodes", &arg_nodes); + helper.DeclareField("heads", &heads); + helper.ReadAllFields(reader); +} + +void StaticGraph::Load(const std::string& json) { + std::istringstream is(json); + dmlc::JSONReader reader(&is); + reader.Read(this); +} + +void StaticGraph::Save(std::string* json) const { + std::ostringstream os; + dmlc::JSONWriter writer(&os); + writer.Write(*this); + *json = os.str(); +} + } // namespace mxnet diff --git a/tests/python/test_conv.py b/tests/python/test_conv.py index 9ab34ce1c8ae..d63a0542ce7a 100644 --- a/tests/python/test_conv.py +++ b/tests/python/test_conv.py @@ -12,12 +12,12 @@ def CalAcc(out, label): # symbol net batch_size = 100 data 
= mx.symbol.Variable('data') -conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2), nstep=100) +conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2)) bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1") act1 = mx.symbol.Activation(data = bn1, name='relu1', act_type="relu") mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max') -conv2= mx.symbol.Convolution(data = mp1, name='conv2', num_filter=32, kernel=(3,3), stride=(2,2), nstep=100) +conv2= mx.symbol.Convolution(data = mp1, name='conv2', num_filter=32, kernel=(3,3), stride=(2,2)) bn2 = mx.symbol.BatchNorm(data = conv2, name="bn2") act2 = mx.symbol.Activation(data = bn2, name='relu2', act_type="relu") mp2 = mx.symbol.Pooling(data = act2, name = 'mp2', kernel=(2,2), stride=(2,2), pool_type='max')