diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py
index 7c2d20789365..cc9bafadaf70 100644
--- a/example/cifar10/cifar10.py
+++ b/example/cifar10/cifar10.py
@@ -156,127 +156,34 @@ def RandomInit(narray):
     flatten = mx.symbol.Flatten(data=pool, name="flatten1")
     fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
     loss = mx.symbol.Softmax(data=fc, name="loss")
-
-
-epoch = 9
-lr = 0.05
-wd = 0.0001
-momentum = 0.9
-
-batch_size = 128
-data_shape = (batch_size, 3, 28, 28)
-
-in_data = mx.nd.empty(data_shape, mx.gpu())
-executor = loss.simple_bind(mx.gpu(), data = in_data)
-
-
-out_narray = executor.outputs[0]
-pred = mx.nd.zeros(out_narray.shape, mx.cpu())
-
-arg_narrays, grad_narrays = executor.list_arguments()
-inputs = dict(zip(loss.list_arguments(), arg_narrays))
-tmp_label = mx.nd.zeros(inputs["loss_label"].shape)
-momentum_narrays = [mx.nd.zeros(item.shape, mx.gpu()) for item in grad_narrays]
-
-block = list(zip(grad_narrays, arg_narrays, momentum_narrays))
-
-np.random.seed(0)
-
-
-for name, narray in inputs.items():
-    if "weight" in name:
-        narray[:] = np.random.uniform(-0.1, 0.1, narray.shape)
-    if "bias" in name:
-        narray[:] = 0.0
-    if "gamma" in name:
-        narray[:] = 1.0
-    if "beta" in name:
-        narray[:] = 0.0
-
-def Update(grad, weight, mom):
-    mom[:] *= momentum
-    mom[:] += -lr * (grad / batch_size + wd * weight)
-    weight[:] += mom
-
-#check data
+#########################################################
 get_data.GetCifar10()
-
+batch_size = 128
+epoch = 3
 train_dataiter = mx.io.ImageRecordIter(
         path_imgrec="data/cifar/train.rec",
         mean_img="data/cifar/cifar_mean.bin",
         rand_crop=True,
         rand_mirror=True,
-        shuffle=False,
         input_shape=(3,28,28),
         batch_size=batch_size,
-        nthread=4,
-        prefetch_capacity=6)
+        nthread=1)
 test_dataiter = mx.io.ImageRecordIter(
         path_imgrec="data/cifar/test.rec",
         mean_img="data/cifar/cifar_mean.bin",
         rand_crop=False,
         rand_mirror=False,
-        shuffle=False,
         input_shape=(3,28,28),
         batch_size=batch_size,
-        nthread=4,
-        prefetch_capacity=6)
-
-def progress(count, total, epoch, toc):
-    bar_len = 50
-    filled_len = int(round(bar_len * count / float(total)))
-
-    percents = round(100.0 * count / float(total), 1)
-    bar = '=' * filled_len + '-' * (bar_len - filled_len)
-    tic = time.time()
-    speed = batch_size / float(tic - toc)
-    suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed)
-    sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix))
-
+        nthread=1)
 def test_cifar():
-    acc_train = 0.
-    acc_val = 0.
- print("Start training...") - for i in range(epoch): - # train - train_acc = 0.0 - val_acc = 0.0 - train_nbatch = 0 - val_nbatch = 0 - all_train_bacth = round(50000 / float(batch_size) + 1) - for data, label in train_dataiter: - toc = time.time() - label = label.asnumpy().flatten() - tmp_label[:] = label - inputs["data"][:] = data - inputs["loss_label"][:] = tmp_label - executor.forward() - pred[:] = out_narray - train_acc += CalAcc(pred.asnumpy(), label) - train_nbatch += 1 - #executor.backward([out_narray]) - executor.backward() - - for grad, weight, mom in block: - Update(grad, weight, mom) - progress(train_nbatch, all_train_bacth, i, toc) + model = mx.model.MXNetModel(ctx=mx.gpu(), + symbol=loss, data=(batch_size, 3, 28, 28), + optimizer="sgd", num_round = epoch, batch_size = batch_size, + learning_rate=0.05, momentum=0.9, weight_decay=0.00001) + model.fit(X=train_dataiter, eval_set=test_dataiter, eval_metric=CalAcc) - # evaluate - for data, label in test_dataiter: - label = label.asnumpy().flatten() - inputs["data"][:] = data - executor.forward() - pred[:] = out_narray - val_acc += CalAcc(pred.asnumpy(), label) - val_nbatch += 1 - acc_train = train_acc / train_nbatch - acc_val = val_acc / val_nbatch - sys.stdout.write('\n') - print("Train Acc: ", train_acc / train_nbatch) - print("Valid Acc: ", val_acc / val_nbatch) - train_dataiter.reset() - test_dataiter.reset() if __name__ == "__main__": test_cifar() diff --git a/mshadow b/mshadow index 2b6c218f6f6f..c6f53473ee4b 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 2b6c218f6f6fd677186eee9eb0a9ff64a57ead70 +Subproject commit c6f53473ee4bfd834bf38cd3ff630e395ff662b4 diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 287cd837275e..e3b17baa1b31 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -18,5 +18,9 @@ # use mx.nd as short for mx.ndarray from . import ndarray as nd from . import random +from . import optimizer +from . import model +from . import initializer +import atexit __version__ = "0.1.0" diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py new file mode 100644 index 000000000000..6a3bbb385927 --- /dev/null +++ b/python/mxnet/initializer.py @@ -0,0 +1,124 @@ +# pylint: skip-file +import numpy as np +from .ndarray import NDArray +from . 
import random + +class Initializer(object): + """Base class for Initializer""" + def __init__(self, **kwargs): + """Constructor + + Parameters + ---------- + kwargs: dict + potential parameters for Initializer implmentation + """ + self.args = kwargs + + def init_weight(self): + """Abstruct method to Initialize weight""" + raise NotImplementedError("Must override it") + + def __call__(self, state, arr): + """Override () function to do Initialization + + Parameters: + ---------- + state: str + name of corrosponding ndarray + arr: NDArray + ndarray to be Initialized + """ + assert(isinstance(state, str)) + assert(isinstance(arr, NDArray)) + if "weight" in state: + self.init_weight(arr) + if "bias" in state: + arr[:] = 0.0 + if "gamma" in state: + arr[:] = 1.0 + if "beta" in state: + arr[:] = 0.0 + + def get_fan(self, shape): + """Get input/output from shape + + Parameter + --------- + shape: tuple + shape of NDArray + + Returns + ------- + fan_in: int + input dim + fan_out: int + output dim + """ + fan_in = shape[1] + fan_out = shape[0] + return fan_in, fan_out + +class Uniform(Initializer): + """Uniform Initializer""" + def __init__(self, scale=0.07): + """Constructor + + Parameter + --------- + scale: float (default=0.07) + unifrom range [-scale, scale] + """ + super().__init__(scale = scale) + + def init_weight(self, arr): + """Implmentation of abs method + + Parameter + -------- + arr: NDArray + NDArray to be Initialized + """ + if isinstance(arr, NDArray): + arr[:] = random.uniform(-scale, scale, arr.shape) + else: + raise TypeError("Input array must be NDArray") + +class Normal(Initializer): + """Gaussian Initializer""" + def __init__(self, sigma=0.01): + """Constuctor of Normal Initializer + Parameter + -------- + sigma: float (default=0.01) + sigma for gaussian distribution + """ + super().__init__(sigma = sigma) + def init_weight(self, arr): + """Implmentation of abs method + + Parameter + -------- + arr: NDArray + NDArray to be Initialized + """ + if isinstance(arr, NDArray): + arr[:] = random.normal(0, sigma, arr.shape) + else: + raise TypeError("Input array must be NDArray") + +class Xavier(Initializer): + def init_weight(self, arr): + """Implmentation of abs method + + Parameter + -------- + arr: NDArray + NDArray to be Initialized + """ + if isinstance(arr, NDArray): + fan_in, fan_out = self.get_fan(arr.shape) + s = np.sqrt(6. 
/ (fan_in + fan_out)) + arr[:] = random.uniform(-s, s, arr.shape) + else: + raise TypeError("Input array must be NDArray") diff --git a/python/mxnet/model.py b/python/mxnet/model.py new file mode 100644 index 000000000000..e77761d2135b --- /dev/null +++ b/python/mxnet/model.py @@ -0,0 +1,169 @@ +# pylint: skip-file +import numpy as np +import time + +from .io import DataIter +from .context import Context +from .ndarray import empty, zeros +from .initializer import Xavier +from .symbol import Symbol +from .optimizer import get_optimizer + +Base = object +try: + from sklearn.base import BaseEstimator + Base = BaseEstimator +except ImportError: + SKLEARN_INSTALLED = False + + +class MXNetModel(object): + """MXNet model""" + def __init__(self, ctx, symbol, num_round, batch_size, optimizer="sgd", initializer=Xavier(), **kwargs): + """Constructor + + Parameter + --------- + ctx: Context or list of Context + running context for model, if is a list, run a multiply device + symbol: Symbol + symbol of the model + num_round: int + training num round + batch_size: int + batch size + optimizer: str + optimizer used to train the model + initializer: Initializer + initializer used to initialize weight + kwargs: dict + optimizer arguments and input data shape + """ + if not isinstance(symbol, Symbol): + raise TypeError("symbol") + if num_round <= 0: + raise ValueError("num_round must be greater than 0") + self.ctx = ctx + self.optimizer = get_optimizer(name=optimizer, batch_size=batch_size, **kwargs) + self.num_round = num_round + self.initializer = initializer + self.shape_dict = kwargs + self.symbol = symbol + # check shape and batch size + arg_shapes, out_shapes, aux_shapes = self.symbol.infer_shape(**kwargs) + if arg_shapes == None: + raise ValueError("input shape is incomplete") + + def fit(self, X, y=None, eval_set=None, eval_metric=None): + """fit the model + + Parameter + --------- + X: DataIter or numpy.ndarray(TODO) + training data + y: None or numpy.ndarray + if X is DataIter no need to set (use None) + if X is numpy.ndarray y is required to set + eval_set: DataIter or numpy.ndarray pair (TODO) + if eval_set is numpy.ndarray pair, it should be (valid_data, valid_label) + eval_metric: function + """ + self.executor = self.symbol.simple_bind(ctx=self.ctx, **self.shape_dict) + # init + arg_narrays, grad_narrays = self.executor.list_arguments() + inputs = dict(zip(self.symbol.list_arguments(), arg_narrays)) + arg_blocks = list(zip(arg_narrays, grad_narrays, self.symbol.list_arguments())) + # only support 1 output now + # find label + label_node_name = "" + data_node_name = "" + for name, ndarray in inputs.items(): + if "label" in name: + label_node_name = name + if "data" in name: + data_node_name = name + # single output + out_ndarray = self.executor.outputs[0] + pred = zeros(out_ndarray.shape) + for state, narray in inputs.items(): + self.initializer(state, narray) + for i in range(self.num_round): + print("Epoch %d:" % i) + #train + train_acc = 0.0 + val_acc = 0.0 + train_nbatch = 0 + val_nbatch = 0 + tic = time.time() + for data, label in X: + label = label.asnumpy().flatten() + inputs[label_node_name][:] = label + inputs[data_node_name][:] = data + self.executor.forward() + pred[:] = out_ndarray + train_nbatch += 1 + self.executor.backward() + + for weight, grad, state in arg_blocks: + self.optimizer(weight, grad, state) + + train_acc += eval_metric(pred.asnumpy(), label) + toc = time.time() + print("Time: %.3f" % (toc - tic)) + + # eval + for data, label in eval_set: + label = 
label.asnumpy().flatten() + inputs[data_node_name][:] = data + self.executor.forward() + val_acc += eval_metric(out_ndarray.asnumpy(), label) + val_nbatch += 1 + + print("Train Acc: ", train_acc / train_nbatch) + print("Valid Acc: ", val_acc / val_nbatch) + X.reset() + eval_set.reset() + + def save(self, path): + """save model + + Parameter + --------- + path: str + saving path + """ + raise NotImplementedError("TODO") + + def load(self, path): + """load model + + Parameter + --------- + path: str + saving path + """ + raise NotImplementedError("TODO") + + def draw(self, path): + """draw model + + Parameter + --------- + path: str + saving path + """ + raise NotImplementedError("TODO") + +""" +class MXNetClassifier(MXNetModel): + def __init__(self, ctx, symbol, optimizer, num_round, batch_size, initializer=xavier, **kwargs): + super(MXNetClassifier, self).__init__(ctx, symbol, optimizer, + num_round, batch_size, initializer, **kwargs) + + def predict(self): + pass + def predict_proba(self, X): + pass + def score(self, X, y): + pass +""" diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py new file mode 100644 index 000000000000..fdbc2e196c5e --- /dev/null +++ b/python/mxnet/optimizer.py @@ -0,0 +1,67 @@ +# pylint: skip-file +from .ndarray import NDArray, zeros + +def get_optimizer(name, batch_size=1, **kwargs): + """Optimizer factory + + Parameters + ---------- + name: str + Name of required optimizer + batch_size: int + batch size, used to normalize gradient + kwargs: dict + Parameters for optimizer + + Return + ---------- + A required optimizer object + + """ + if name == "sgd" or name == "SGD": + return SGD(batch_size=batch_size, **kwargs) + else: + raise Exception("Not implemented") + +class SGD(object): + """A very simple SGD optimizer with Nesterov method""" + def __init__(self, learning_rate=0.01, momentum=0.9, weight_decay=0.0001, batch_size=1, **kwargs): + """ + Parameter + ---------- + learning_rate: float + learning_rate value + momentum: float + momentum value + weight_decay: float + L2 regularization coefficient + batch_size: int + batch size, used to norm gradient + """ + + self.lr = learning_rate + self.momentum = momentum + self.wd = weight_decay + self.batch_size = batch_size + self.momentums = {} + + def __call__(self, weight, grad, states): + """ + Parameter + --------- + weight: NDArray + weight ndarray + grad: NDArray + grad ndarray + states: str + name of weight + """ + assert(isinstance(weight, NDArray)) + assert(isinstance(grad, NDArray)) + if states not in self.momentums: + self.momentums[states] = zeros(grad.shape, grad.context) + mom = self.momentums[states] + mom[:] *= self.momentum + mom[:] += -self.lr * (grad / self.batch_size + self.wd * weight) + weight[:] += mom + diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 1baa70588274..589503727ce1 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -245,11 +245,10 @@ def infer_shape(self, *args, **kwargs): else: keys = [] for k, v in kwargs.items(): - keys.append(c_str(k)) - if not isinstance(v, tuple): - raise TypeError('Argument need to be shapes(tuple)') - sdata.extend(v) - indptr.append(len(sdata)) + if isinstance(v, tuple): + keys.append(c_str(k)) + sdata.extend(v) + indptr.append(len(sdata)) arg_shape_size = mx_uint() arg_shape_ndim = ctypes.POINTER(mx_uint)() arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() @@ -405,31 +404,31 @@ def simple_bind(self, ctx, grad_req='write', **kwargs): - 'write' means everytime gradient is write to specified args_grad 
NDArray. - 'add' means everytime gradient is add to the specified NDArray. - 'null' means no action is taken, the gradient may not be calculated. - kwargs : dict of str to NDArray + kwargs : dict of str->shape + Input shape dictionary, name->shape Returns ------- executor : mxnet.Executor The generated Executor """ - input_shapes = dict((name, arr.shape) for name, arr in kwargs.items()) # pylint: disable=unused-variable - arg_shapes, out_shapes, aux_shapes = self.infer_shape(**input_shapes) + arg_shapes, out_shapes, aux_shapes = self.infer_shape(**kwargs) # pylint: enable=unused-variable if arg_shapes == None: raise ValueError("Input node is not complete") # alloc space - arg_ndarrays = [] - for name, shape in zip(self.list_arguments(), arg_shapes): - if name in kwargs: - arg_ndarrays.append(kwargs[name]) + arg_ndarrays = [zeros(shape, ctx) for shape in arg_shapes] + req = {} + for state in self.list_arguments(): + if "data" in state: + req[state] = "null" else: - arg_ndarrays.append(zeros(shape, ctx)) - # TODO(bing): specail treat input data grad + req[state] = grad_req # TODO(bing): not generate grad case grad_ndarrays = [zeros(shape, ctx) for shape in arg_shapes] aux_ndarrays = [zeros(shape, ctx) for shape in aux_shapes] - executor = self.bind(ctx, arg_ndarrays, grad_ndarrays, grad_req, aux_ndarrays) + executor = self.bind(ctx, arg_ndarrays, grad_ndarrays, req, aux_ndarrays) return executor def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): @@ -522,7 +521,7 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): req_array.append(mx_uint(req_map[grad_req[name]])) else: req_array.append(mx_uint(0)) - req_array = c_array(mx_uint, req_array) + reqs_array = c_array(mx_uint, req_array) handle = ExecutorHandle() check_call(_LIB.MXExecutorBind(self.handle,
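
Usage sketch (editor's addition, not part of the patch): a minimal example of driving the new mx.model.MXNetModel / fit() API introduced above, mirroring the rewritten example/cifar10/cifar10.py. The toy network, the CalAcc metric, and the record-file paths are illustrative assumptions rather than part of this patch; fit() locates the input and label arguments by the substrings "data" and "label", so any Symbol named that way should work.

# Illustrative driver for the MXNetModel API added in this patch.
# Assumptions are marked; shapes follow the CIFAR-10 example (3x28x28 crops).
import numpy as np
import mxnet as mx

def CalAcc(pred, label):
    # pred: (batch_size, num_class) softmax outputs; label: (batch_size,) int labels
    return np.sum(np.argmax(pred, axis=1) == label) / float(label.size)

# Assumed toy network. Softmax names its label argument "loss_label",
# which fit() discovers via the "label" substring.
net = mx.symbol.Variable("data")
net = mx.symbol.FullyConnected(data=mx.symbol.Flatten(data=net),
                               num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=net, name="loss")

batch_size = 128
train_dataiter = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",      # assumed path
        mean_img="data/cifar/cifar_mean.bin",    # assumed path
        rand_crop=True, rand_mirror=True,
        input_shape=(3, 28, 28), batch_size=batch_size, nthread=1)
test_dataiter = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/test.rec",       # assumed path
        mean_img="data/cifar/cifar_mean.bin",
        rand_crop=False, rand_mirror=False,
        input_shape=(3, 28, 28), batch_size=batch_size, nthread=1)

# data=(batch_size, 3, 28, 28) travels through **kwargs into infer_shape() and
# simple_bind(); the patched infer_shape() keeps only tuple-valued kwargs, so the
# optimizer hyper-parameters pass through harmlessly to get_optimizer("sgd", ...).
model = mx.model.MXNetModel(ctx=mx.gpu(), symbol=loss,
                            num_round=3, batch_size=batch_size,
                            optimizer="sgd", data=(batch_size, 3, 28, 28),
                            learning_rate=0.05, momentum=0.9, weight_decay=0.00001)
model.fit(X=train_dataiter, eval_set=test_dataiter, eval_metric=CalAcc)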