Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 10 additions & 103 deletions example/cifar10/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,127 +156,34 @@ def RandomInit(narray):
flatten = mx.symbol.Flatten(data=pool, name="flatten1")
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=fc, name="loss")


epoch = 9
lr = 0.05
wd = 0.0001
momentum = 0.9

batch_size = 128
data_shape = (batch_size, 3, 28, 28)

in_data = mx.nd.empty(data_shape, mx.gpu())
executor = loss.simple_bind(mx.gpu(), data = in_data)


out_narray = executor.outputs[0]
pred = mx.nd.zeros(out_narray.shape, mx.cpu())

arg_narrays, grad_narrays = executor.list_arguments()
inputs = dict(zip(loss.list_arguments(), arg_narrays))
tmp_label = mx.nd.zeros(inputs["loss_label"].shape)
momentum_narrays = [mx.nd.zeros(item.shape, mx.gpu()) for item in grad_narrays]

block = list(zip(grad_narrays, arg_narrays, momentum_narrays))

np.random.seed(0)


for name, narray in inputs.items():
if "weight" in name:
narray[:] = np.random.uniform(-0.1, 0.1, narray.shape)
if "bias" in name:
narray[:] = 0.0
if "gamma" in name:
narray[:] = 1.0
if "beta" in name:
narray[:] = 0.0

def Update(grad, weight, mom):
    """Apply one momentum update to ``weight`` in place.

    Reads the module-level ``momentum``, ``lr``, ``wd`` and ``batch_size``
    settings.

    Parameters
    ----------
    grad : array
        Gradient for ``weight``; divided by ``batch_size`` before use.
    weight : array
        Parameter array, updated in place.
    mom : array
        Momentum buffer for ``weight``, updated in place.
    """
    # Decay the previous momentum first, exactly as before.
    mom[:] *= momentum
    # Scaled gradient plus the weight-decay term, negated by the learning rate.
    step = -lr * (grad / batch_size + wd * weight)
    mom[:] += step
    weight[:] += mom

#check data
#########################################################
get_data.GetCifar10()

batch_size = 128
epoch = 3
train_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=True,
rand_mirror=True,
shuffle=False,
input_shape=(3,28,28),
batch_size=batch_size,
nthread=4,
prefetch_capacity=6)
nthread=1)
test_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/test.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=False,
rand_mirror=False,
shuffle=False,
input_shape=(3,28,28),
batch_size=batch_size,
nthread=4,
prefetch_capacity=6)

def progress(count, total, epoch, toc):
    """Write a one-line text progress bar to stdout.

    The line ends with a carriage return so the next call overwrites it.

    Parameters
    ----------
    count : int
        Current progress value.
    total : int or float
        Value of ``count`` that corresponds to 100%.
    epoch : int
        Epoch number shown in the suffix.
    toc : float
        Timestamp in ``time.time()`` units from which the elapsed time for
        the speed estimate is measured.
    """
    bar_len = 50
    filled_len = int(round(bar_len * count / float(total)))

    percents = round(100.0 * count / float(total), 1)
    bar = '=' * filled_len + '-' * (bar_len - filled_len)
    tic = time.time()
    elapsed = float(tic - toc)
    # A coarse clock can report no elapsed time at all; dividing by it would
    # raise ZeroDivisionError and kill the training loop over a cosmetic
    # progress line, so report an unbounded speed instead.
    if elapsed > 0:
        speed = batch_size / elapsed
    else:
        speed = float('inf')
    suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed)
    sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix))

nthread=1)

def test_cifar():
acc_train = 0.
acc_val = 0.
print("Start training...")
for i in range(epoch):
# train
train_acc = 0.0
val_acc = 0.0
train_nbatch = 0
val_nbatch = 0
all_train_bacth = round(50000 / float(batch_size) + 1)
for data, label in train_dataiter:
toc = time.time()
label = label.asnumpy().flatten()
tmp_label[:] = label
inputs["data"][:] = data
inputs["loss_label"][:] = tmp_label
executor.forward()
pred[:] = out_narray
train_acc += CalAcc(pred.asnumpy(), label)
train_nbatch += 1
#executor.backward([out_narray])
executor.backward()

for grad, weight, mom in block:
Update(grad, weight, mom)
progress(train_nbatch, all_train_bacth, i, toc)
model = mx.model.MXNetModel(ctx=mx.gpu(),
symbol=loss, data=(batch_size, 3, 28, 28),
optimizer="sgd", num_round = epoch, batch_size = batch_size,
learning_rate=0.05, momentum=0.9, weight_decay=0.00001)
model.fit(X=train_dataiter, eval_set=test_dataiter, eval_metric=CalAcc)

# evaluate
for data, label in test_dataiter:
label = label.asnumpy().flatten()
inputs["data"][:] = data
executor.forward()
pred[:] = out_narray
val_acc += CalAcc(pred.asnumpy(), label)
val_nbatch += 1
acc_train = train_acc / train_nbatch
acc_val = val_acc / val_nbatch
sys.stdout.write('\n')
print("Train Acc: ", train_acc / train_nbatch)
print("Valid Acc: ", val_acc / val_nbatch)
train_dataiter.reset()
test_dataiter.reset()

if __name__ == "__main__":
test_cifar()
2 changes: 1 addition & 1 deletion mshadow
4 changes: 4 additions & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,9 @@
# use mx.nd as short for mx.ndarray
from . import ndarray as nd
from . import random
from . import optimizer
from . import model
from . import initializer
import atexit

__version__ = "0.1.0"
124 changes: 124 additions & 0 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# pylint: skip-file
import numpy as np
from .ndarray import NDArray
from . import random

class Initializer(object):
    """Base class for Initializer.

    Subclasses override ``init_weight``; calling an instance picks the
    initialization rule from the parameter name.
    """

    def __init__(self, **kwargs):
        """Constructor

        Parameters
        ----------
        kwargs : dict
            Potential parameters for the Initializer implementation; stored
            unchanged on ``self.args``.
        """
        self.args = kwargs

    def init_weight(self, arr=None):
        """Abstract method to initialize a weight array.

        Parameters
        ----------
        arr : NDArray
            ndarray to be initialized.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        # ``__call__`` invokes ``self.init_weight(arr)``, so the abstract
        # method has to accept the array; otherwise the base class fails with
        # TypeError instead of this NotImplementedError. The default keeps the
        # old zero-argument call working.
        raise NotImplementedError("Must override it")

    def __call__(self, state, arr):
        """Override () function to do initialization.

        Parameters
        ----------
        state : str
            Name of the corresponding ndarray.
        arr : NDArray
            ndarray to be initialized.
        """
        assert isinstance(state, str)
        assert isinstance(arr, NDArray)
        # The checks are independent substring tests, not an elif chain: a
        # name matching several keywords has each matching rule applied in
        # this order.
        if "weight" in state:
            self.init_weight(arr)
        if "bias" in state:
            arr[:] = 0.0
        if "gamma" in state:
            arr[:] = 1.0
        if "beta" in state:
            arr[:] = 0.0

    def get_fan(self, shape):
        """Get input/output dimensions from a shape.

        Only the first two entries of ``shape`` are used.

        Parameters
        ----------
        shape : tuple
            Shape of the NDArray; needs at least two entries.

        Returns
        -------
        fan_in : int
            Input dim, taken from ``shape[1]``.
        fan_out : int
            Output dim, taken from ``shape[0]``.
        """
        fan_in = shape[1]
        fan_out = shape[0]
        return fan_in, fan_out

class Uniform(Initializer):
    """Uniform Initializer: weights are filled from ``[-scale, scale]``."""

    def __init__(self, scale=0.07):
        """Constructor

        Parameters
        ----------
        scale : float (default=0.07)
            Uniform range is [-scale, scale].
        """
        # The explicit two-argument form also works on Python 2, where
        # zero-argument super() is not available.
        super(Uniform, self).__init__(scale=scale)

    def init_weight(self, arr):
        """Implementation of the abstract method.

        Parameters
        ----------
        arr : NDArray
            NDArray to be initialized.

        Raises
        ------
        TypeError
            If ``arr`` is not an NDArray.
        """
        if not isinstance(arr, NDArray):
            raise TypeError("Input array must be NDArray")
        # The constructor argument is kept in ``self.args`` by the base
        # class; the bare name ``scale`` is not in scope here and raised
        # NameError on every call.
        scale = self.args["scale"]
        arr[:] = random.uniform(-scale, scale, arr.shape)

class Normal(Initializer):
    """Gaussian Initializer: weights are filled with mean 0 and ``sigma``."""

    def __init__(self, sigma=0.01):
        """Constructor of Normal Initializer

        Parameters
        ----------
        sigma : float (default=0.01)
            Sigma for the gaussian distribution.
        """
        # The explicit two-argument form also works on Python 2, where
        # zero-argument super() is not available.
        super(Normal, self).__init__(sigma=sigma)

    def init_weight(self, arr):
        """Implementation of the abstract method.

        Parameters
        ----------
        arr : NDArray
            NDArray to be initialized.

        Raises
        ------
        TypeError
            If ``arr`` is not an NDArray.
        """
        if not isinstance(arr, NDArray):
            raise TypeError("Input array must be NDArray")
        # The constructor argument is kept in ``self.args`` by the base
        # class; the bare name ``sigma`` is not in scope here and raised
        # NameError on every call.
        sigma = self.args["sigma"]
        arr[:] = random.normal(0, sigma, arr.shape)

class Xavier(Initializer):
    """Initializer whose uniform range is derived from the array's fan-in/out."""

    def init_weight(self, arr):
        """Implementation of the abstract method.

        The range is ``[-s, s]`` with ``s = sqrt(6 / (fan_in + fan_out))``,
        where the fans come from ``self.get_fan(arr.shape)``.

        Parameters
        ----------
        arr : NDArray
            NDArray to be initialized.

        Raises
        ------
        TypeError
            If ``arr`` is not an NDArray.
        """
        # Reject anything that is not an NDArray before touching its shape.
        if not isinstance(arr, NDArray):
            raise TypeError("Input array must be NDArray")
        fans = self.get_fan(arr.shape)
        n_in, n_out = fans
        bound = np.sqrt(6. / (n_in + n_out))
        arr[:] = random.uniform(-bound, bound, arr.shape)
Loading