Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def verify_loaded_model(net):
return data.astype(np.float32)/255, label.astype(np.float32)

# Load ten random images from the test dataset
sample_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform),
sample_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False).transform(transform),
10, shuffle=True)

for data, label in sample_data:
Expand Down
9 changes: 4 additions & 5 deletions example/distributed_training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ We can then create a `DataLoader` using the `SplitSampler` like shown below:

```python
# Load the training data
train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=True, transform=transform),
train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=True).transform(transform),
batch_size,
sampler=SplitSampler(50000, store.num_workers, store.rank))
```
Expand All @@ -141,7 +141,7 @@ def train_batch(batch, ctx, net, trainer):
# Split and load data into multiple GPUs
data = batch[0]
data = gluon.utils.split_and_load(data, ctx)

# Split and load label into multiple GPUs
label = batch[1]
label = gluon.utils.split_and_load(label, ctx)
Expand Down Expand Up @@ -204,7 +204,7 @@ python ~/mxnet/tools/launch.py -n 2 -s 2 -H hosts \
Let's take a look at the `hosts` file.

```
~/dist$ cat hosts
~/dist$ cat hosts
d1
d2
```
Expand Down Expand Up @@ -232,7 +232,7 @@ Last login: Wed Jan 31 18:06:45 2018 from 72.21.198.67
Note that no authentication information was provided to log in to the host. This can be done using multiple methods. One easy way is to specify the ssh certificates in `~/.ssh/config`. Example:

```
~$ cat ~/.ssh/config
~$ cat ~/.ssh/config
Host d1
HostName ec2-34-201-108-233.compute-1.amazonaws.com
port 22
Expand Down Expand Up @@ -269,4 +269,3 @@ Epoch 4: Test_acc 0.687900
```

Note that the output from all hosts is merged and printed to the console.

4 changes: 2 additions & 2 deletions example/distributed_training/cifar10_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ def __len__(self):


# Load the training data
train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=True, transform=transform), batch_size,
train_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=True).transform(transform), batch_size,
sampler=SplitSampler(50000, store.num_workers, store.rank))

# Load the test data
test_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=False, transform=transform),
test_data = gluon.data.DataLoader(gluon.data.vision.CIFAR10(train=False).transform(transform),
batch_size, shuffle=False)

# Use ResNet from model zoo
Expand Down
8 changes: 4 additions & 4 deletions example/gluon/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='
train_dir = os.path.join(root, 'train')
train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)
logging.info("Loading image folder %s, this may take a bit long...", train_dir)
train_dataset = ImageFolderDataset(train_dir, transform=train_transform)
train_dataset = ImageFolderDataset(train_dir).transform_first(train_transform)
train_data = DataLoader(train_dataset, batch_size, shuffle=True,
last_batch='discard', num_workers=num_workers)
val_dir = os.path.join(root, 'val')
if not os.path.isdir(os.path.expanduser(os.path.join(root, 'val', 'n01440764'))):
user_warning = 'Make sure validation images are stored in one subdir per category, a helper script is available at https://git.io/vNQv1'
raise ValueError(user_warning)
logging.info("Loading image folder %s, this may take a bit long...", val_dir)
val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
val_dataset = ImageFolderDataset(val_dir).transform(val_transform)
val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype)

Expand Down Expand Up @@ -118,8 +118,8 @@ def transform(image, label):
return transposed, label

training_path, testing_path = get_caltech101_data()
dataset_train = ImageFolderDataset(root=training_path, transform=transform)
dataset_test = ImageFolderDataset(root=testing_path, transform=transform)
dataset_train = ImageFolderDataset(root=training_path).transform(transform)
dataset_test = ImageFolderDataset(root=testing_path).transform(transform)

train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
Expand Down
8 changes: 4 additions & 4 deletions example/gluon/dc_gan/dcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,20 @@ def get_dataset(dataset_name):
# mnist
if dataset == "mnist":
train_data = gluon.data.DataLoader(
gluon.data.vision.MNIST('./data', train=True, transform=transformer),
gluon.data.vision.MNIST('./data', train=True).transform(transformer),
batch_size, shuffle=True, last_batch='discard')

val_data = gluon.data.DataLoader(
gluon.data.vision.MNIST('./data', train=False, transform=transformer),
gluon.data.vision.MNIST('./data', train=False).transform(transformer),
batch_size, shuffle=False)
# cifar10
elif dataset == "cifar10":
train_data = gluon.data.DataLoader(
gluon.data.vision.CIFAR10('./data', train=True, transform=transformer),
gluon.data.vision.CIFAR10('./data', train=True).transform(transformer),
batch_size, shuffle=True, last_batch='discard')

val_data = gluon.data.DataLoader(
gluon.data.vision.CIFAR10('./data', train=False, transform=transformer),
gluon.data.vision.CIFAR10('./data', train=False).transform(transformer),
batch_size, shuffle=False)

return train_data, val_data
Expand Down
4 changes: 2 additions & 2 deletions example/gluon/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ def transformer(data, label):
return data, label

train_data = gluon.data.DataLoader(
gluon.data.vision.MNIST('./data', train=True, transform=transformer),
gluon.data.vision.MNIST('./data', train=True).transform(transformer),
batch_size=opt.batch_size, shuffle=True, last_batch='discard')

val_data = gluon.data.DataLoader(
gluon.data.vision.MNIST('./data', train=False, transform=transformer),
gluon.data.vision.MNIST('./data', train=False).transform(transformer),
batch_size=opt.batch_size, shuffle=False)

# train
Expand Down
2 changes: 1 addition & 1 deletion example/gluon/sn_gan/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ def transformer(data, label):
def get_training_data(batch_size):
""" helper function to get dataloader"""
return gluon.data.DataLoader(
CIFAR10(train=True, transform=transformer),
CIFAR10(train=True).transform(transformer),
batch_size=batch_size, shuffle=True, last_batch='discard')
6 changes: 3 additions & 3 deletions example/restricted-boltzmann-machine/binary_rbm_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def get_non_auxiliary_params(rbm):
def data_transform(data, label):
return data.astype(np.float32) / 255, label.astype(np.float32)

mnist_train_dataset = mx.gluon.data.vision.MNIST(train=True, transform=data_transform)
mnist_test_dataset = mx.gluon.data.vision.MNIST(train=False, transform=data_transform)
mnist_train_dataset = mx.gluon.data.vision.MNIST(train=True).transform(data_transform)
mnist_test_dataset = mx.gluon.data.vision.MNIST(train=False).transform(data_transform)
img_height = mnist_train_dataset[0][0].shape[0]
img_width = mnist_train_dataset[0][0].shape[1]
num_visible = img_width * img_height
Expand Down Expand Up @@ -139,4 +139,4 @@ def data_transform(data, label):
plt.axvline(showcase_num_samples_w * img_width, color='y')
plt.show(s)

print("Done")
print("Done")
4 changes: 4 additions & 0 deletions python/mxnet/gluon/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ class _DownloadedDataset(Dataset):
"""Base class for MNIST, cifar10, etc."""
def __init__(self, root, transform):
super(_DownloadedDataset, self).__init__()
if transform is not None:
raise DeprecationWarning(
'Directly apply transform to dataset is deprecated. '
'Please use dataset.transform() or dataset.transform_first() instead...')
self._transform = transform
self._data = None
self._label = None
Expand Down
14 changes: 14 additions & 0 deletions python/mxnet/gluon/data/vision/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class MNIST(dataset._DownloadedDataset):
train : bool, default True
Whether to load the training or testing set.
transform : function, default None
DEPRECATED FUNCTION ARGUMENTS.
A user defined callback that transforms each sample. For example::

transform=lambda data, label: (data.astype(np.float32)/255, label)
Expand Down Expand Up @@ -110,6 +111,7 @@ class FashionMNIST(MNIST):
train : bool, default True
Whether to load the training or testing set.
transform : function, default None
DEPRECATED FUNCTION ARGUMENTS.
A user defined callback that transforms each sample. For example::

transform=lambda data, label: (data.astype(np.float32)/255, label)
Expand Down Expand Up @@ -142,6 +144,7 @@ class CIFAR10(dataset._DownloadedDataset):
train : bool, default True
Whether to load the training or testing set.
transform : function, default None
DEPRECATED FUNCTION ARGUMENTS.
A user defined callback that transforms each sample. For example::

transform=lambda data, label: (data.astype(np.float32)/255, label)
Expand Down Expand Up @@ -207,6 +210,7 @@ class CIFAR100(CIFAR10):
train : bool, default True
Whether to load the training or testing set.
transform : function, default None
DEPRECATED FUNCTION ARGUMENTS.
A user defined callback that transforms each sample. For example::

transform=lambda data, label: (data.astype(np.float32)/255, label)
Expand Down Expand Up @@ -243,13 +247,18 @@ class ImageRecordDataset(dataset.RecordFileDataset):
If 0, always convert images to greyscale. \
If 1, always convert images to colored (RGB).
transform : function, default None
DEPRECATED FUNCTION ARGUMENTS.
A user defined callback that transforms each sample. For example::

transform=lambda data, label: (data.astype(np.float32)/255, label)

"""
def __init__(self, filename, flag=1, transform=None):
super(ImageRecordDataset, self).__init__(filename)
if transform is not None:
raise DeprecationWarning(
'Directly apply transform to dataset is deprecated. '
'Please use dataset.transform() or dataset.transform_first() instead...')
self._flag = flag
self._transform = transform

Expand Down Expand Up @@ -281,6 +290,7 @@ class ImageFolderDataset(dataset.Dataset):
If 0, always convert loaded images to greyscale (1 channel).
If 1, always convert loaded images to colored (3 channels).
transform : callable, default None
DEPRECATED FUNCTION ARGUMENTS.
A function that takes data and label and transforms them::

transform = lambda data, label: (data.astype(np.float32)/255, label)
Expand All @@ -295,6 +305,10 @@ class ImageFolderDataset(dataset.Dataset):
def __init__(self, root, flag=1, transform=None):
self._root = os.path.expanduser(root)
self._flag = flag
if transform is not None:
raise DeprecationWarning(
'Directly apply transform to dataset is deprecated. '
'Please use dataset.transform() or dataset.transform_first() instead...')
self._transform = transform
self._exts = ['.jpg', '.jpeg', '.png']
self._list_images(self._root)
Expand Down