Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ typedef void *ExecutorHandle;
typedef void *DataIterCreator;
/*! \brief handle to a DataIterator */
typedef void *DataIterHandle;
/*! \brief handle to a Dataset creator */
typedef void *DatasetCreator;
/*! \brief handle to a Dataset */
typedef void *DatasetHandle;
/*! \brief handle to a BatchifyFunction creator */
typedef void *BatchifyFunctionCreator;
/*! \brief handle to a BatchifyFunction */
typedef void *BatchifyFunctionHandle;
/*! \brief handle to KVStore */
typedef void *KVStoreHandle;
/*! \brief handle to RecordIO */
Expand Down Expand Up @@ -2670,6 +2678,13 @@ MXNET_DLL int MXDataIterNext(DataIterHandle handle,
*/
MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle);

/*!
* \brief Call iterator.GetLenHint. Note that some iterators don't provide length.
* \param handle the handle to iterator
* \param len output pointer that receives the length hint
*        (NOTE(review): IIterator::GetLenHint documents that a value < 0 means
*        the length is unknown; confirm this function forwards that value as-is)
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterGetLenHint(DataIterHandle handle,
int64_t *len);
/*!
* \brief Get the handle to the NDArray of underlying data
* \param handle the handle pointer to the data iterator
Expand Down Expand Up @@ -2705,6 +2720,147 @@ MXNET_DLL int MXDataIterGetPadNum(DataIterHandle handle,
*/
MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle,
NDArrayHandle *out);
/*!
* \brief Get the handles to specified underlying ndarrays of index
* \param handle the handle pointer to the data iterator
* \param num_outputs pointer to the length of outputs
* \param outputs pointer to an array of NDArray handles
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterGetItems(DataIterHandle handle,
int* num_outputs,
NDArrayHandle **outputs);

/*!
* \brief List all the available dataset entries
* \param out_size output pointer for the number of returned dataset entries
* \param out_array output pointer for the array of dataset creator handles
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXListDatasets(uint32_t *out_size,
DatasetCreator **out_array);
/*!
* \brief Init a dataset, init with parameters
* \param handle handle of the dataset creator
* \param num_param number of parameters, i.e. the array size of the passed in
*        keys and vals
* \param keys parameter keys
* \param vals parameter values
* \param out resulting dataset
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetCreateDataset(DatasetCreator handle,
uint32_t num_param,
const char **keys,
const char **vals,
DatasetHandle *out);
/*!
* \brief Get the detailed information about dataset.
* \param creator the DatasetCreator.
* \param name The returned name of the creator.
* \param description The returned description of the dataset.
* \param num_args Number of arguments.
* \param arg_names Name of the arguments.
* \param arg_type_infos Type information about the arguments.
* \param arg_descriptions Description information about the arguments.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetGetDatasetInfo(DatasetCreator creator,
const char **name,
const char **description,
uint32_t *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions);
/*!
* \brief Free the handle to the dataset
* \param handle the handle pointer to the dataset
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetFree(DatasetHandle handle);
/*!
* \brief Get dataset overall length (size)
* \param handle the handle to dataset
* \param out output pointer that receives the return value of GetLen
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetGetLen(DatasetHandle handle,
uint64_t *out);
/*!
* \brief Get the output NDArrays of the dataset item at the specified index
* \param handle the handle to dataset
* \param index the index of the dataset item to be retrieved
* \param num_outputs pointer to the number of output ndarrays
* \param outputs the pointers to handles of ndarrays
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetGetItems(DatasetHandle handle,
uint64_t index,
int* num_outputs,
NDArrayHandle **outputs);

/*!
* \brief List all the available batchify function entries
* \param out_size output pointer for the number of returned batchify function entries
* \param out_array output pointer for the array of batchify function creator handles
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXListBatchifyFunctions(uint32_t *out_size,
BatchifyFunctionCreator **out_array);
/*!
* \brief Init a batchify function, init with parameters
* \param handle handle of the batchify function creator
* \param num_param number of parameters, i.e. the array size of the passed in
*        keys and vals
* \param keys parameter keys
* \param vals parameter values
* \param out resulting batchify function
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXBatchifyFunctionCreateFunction(BatchifyFunctionCreator handle,
uint32_t num_param,
const char **keys,
const char **vals,
BatchifyFunctionHandle *out);
/*!
* \brief Get the detailed information about batchify function.
* \param creator the BatchifyFunctionCreator.
* \param name The returned name of the creator.
* \param description The returned description of the batchify function.
* \param num_args Number of arguments.
* \param arg_names Name of the arguments.
* \param arg_type_infos Type information about the arguments.
* \param arg_descriptions Description information about the arguments.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXBatchifyFunctionGetFunctionInfo(BatchifyFunctionCreator creator,
const char **name,
const char **description,
uint32_t *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions);
/*!
* \brief Invoke the Batchify Function
* \param handle the handle pointer to the batchify function
* \param batch_size the batch size
* \param num_output the number of ndarrays for output
* \param inputs the pointers to input ndarrays
* \param outputs the pointers to output ndarrays
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXBatchifyFunctionInvoke(BatchifyFunctionHandle handle,
int batch_size,
int num_output,
NDArrayHandle *inputs,
NDArrayHandle **outputs);
/*!
* \brief Free the handle to the batchify function
* \param handle the handle pointer to the batchify function
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXBatchifyFunctionFree(BatchifyFunctionHandle handle);
//--------------------------------------------
// Part 6: basic KVStore interface
//--------------------------------------------
Expand Down
98 changes: 97 additions & 1 deletion include/mxnet/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ class IIterator : public dmlc::DataIter<DType> {
inline void SetDataName(const std::string data_name) {
data_names.push_back(data_name);
}
/*! \brief request iterator length hint for current epoch.
* Note that the returned value can be < 0, indicating
* that the length of iterator is unknown unless you went through all data.
* \return the length hint; this default implementation always returns -1
* (length unknown).
*/
virtual int64_t GetLenHint(void) const {
return -1;
}
}; // class IIterator

/*! \brief a single data instance */
Expand Down Expand Up @@ -104,7 +111,7 @@ struct DataIteratorReg
*
* \code
* // example of registering a mnist iterator
* REGISTER_IO_ITE(MNISTIter)
* REGISTER_IO_ITER(MNISTIter)
* .describe("Mnist data iterator")
* .set_body([]() {
* return new PrefetcherIter(new MNISTIter());
Expand All @@ -113,5 +120,94 @@ struct DataIteratorReg
*/
#define MXNET_REGISTER_IO_ITER(name) \
DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name)

/*!
* \brief A randomly accessible dataset which provides GetLen() and GetItem().
* Unlike DataIter, it's a static lookup storage which is friendly to random access.
* The dataset itself should NOT contain data processing, which should be applied during
* data augmentation or transformation processes.
*/
class Dataset {
public:
/*!
* \brief Get the size of the dataset
* \return the size of the dataset
*/
virtual uint64_t GetLen(void) const = 0;
/*!
* \brief Get the ndarray items given index in dataset
* \param idx the integer index for required data
* \param ret the returned ndarray items
* \return NOTE(review): the meaning of the bool is not specified here;
* presumably true on success. Confirm against the implementations.
*/
virtual bool GetItem(uint64_t idx, std::vector<NDArray>* ret) = 0;
/*! \brief virtual destructor */
virtual ~Dataset(void) {}
}; // class Dataset

/*!
* \brief typedef the factory function of dataset.
* The factory takes a list of (key, value) string pairs and returns a pointer
* to a Dataset.
*/
typedef std::function<Dataset *(
const std::vector<std::pair<std::string, std::string> >&)> DatasetFactory;
/*!
* \brief Registry entry for Dataset factory functions.
*/
struct DatasetReg
: public dmlc::FunctionRegEntryBase<DatasetReg,
DatasetFactory> {
};
//--------------------------------------------------------------
// The following part is the API registration of Datasets
//--------------------------------------------------------------
/*!
* \brief Macro to register Datasets
*
* \code
* // example of registering an image sequence dataset
* MXNET_REGISTER_IO_DATASET(ImageSequenceDataset)
* .describe("image sequence dataset")
* .set_body([](const std::vector<std::pair<std::string, std::string> >& kwargs) {
* return new ImageSequenceDataset();
* });
* \endcode
*/
#define MXNET_REGISTER_IO_DATASET(name) \
DMLC_REGISTRY_REGISTER(::mxnet::DatasetReg, DatasetReg, name)

/*!
* \brief Abstract interface of a batchify function, which takes a list of
* NDArray lists as input and fills a single list of NDArrays as output.
*/
class BatchifyFunction {
public:
/*! \brief Destructor */
virtual ~BatchifyFunction(void) {}
/*!
* \brief The batchify logic
* \param inputs the input ndarrays, as a vector of NDArray vectors
* (presumably one inner vector per sample; confirm against implementations)
* \param outputs the vector that receives the resulting ndarrays
* \return NOTE(review): the meaning of the bool is not specified here;
* presumably true on success. Confirm against the implementations.
*/
virtual bool Batchify(const std::vector<std::vector<NDArray> >& inputs,
std::vector<NDArray>* outputs) = 0;
}; // class BatchifyFunction

/*! \brief shared pointer to a BatchifyFunction */
using BatchifyFunctionPtr = std::shared_ptr<BatchifyFunction>;

/*!
* \brief typedef the factory function of batchify function.
* The factory takes a list of (key, value) string pairs and returns a pointer
* to a BatchifyFunction.
*/
typedef std::function<BatchifyFunction *(
const std::vector<std::pair<std::string, std::string> >&)> BatchifyFunctionFactory;
/*!
* \brief Registry entry for BatchifyFunction factory functions.
*/
struct BatchifyFunctionReg
: public dmlc::FunctionRegEntryBase<BatchifyFunctionReg,
BatchifyFunctionFactory> {
};
//--------------------------------------------------------------
// The following part is the API registration of Batchify Functions
//--------------------------------------------------------------
/*!
* \brief Macro to register Batchify Functions
*
* \code
* // example of registering a Batchify Function
* MXNET_REGISTER_IO_BATCHIFY_FUNCTION(StackBatchify)
* .describe("Stack Batchify Function")
* .set_body([](const std::vector<std::pair<std::string, std::string> >& kwargs) {
* return new StackBatchify();
* });
* \endcode
*/
#define MXNET_REGISTER_IO_BATCHIFY_FUNCTION(name) \
DMLC_REGISTRY_REGISTER(::mxnet::BatchifyFunctionReg, BatchifyFunctionReg, name)
} // namespace mxnet
#endif // MXNET_IO_H_
2 changes: 2 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ def _load_lib():
ExecutorHandle = ctypes.c_void_p
DataIterCreatorHandle = ctypes.c_void_p
DataIterHandle = ctypes.c_void_p
DatasetHandle = ctypes.c_void_p
# Legacy spelling with a lowercase "h"; kept so existing references keep working.
BatchifyFunctionhandle = ctypes.c_void_p
# Correctly cased alias, consistent with the other *Handle names in this module.
BatchifyFunctionHandle = BatchifyFunctionhandle
KVStoreHandle = ctypes.c_void_p
RecordIOHandle = ctypes.c_void_p
RtcHandle = ctypes.c_void_p
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/gluon/contrib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
"""Contrib datasets."""

from . import text
from . import vision

from .sampler import *
22 changes: 22 additions & 0 deletions python/mxnet/gluon/contrib/data/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Contrib vision utilities."""
from .transforms import *
from .dataloader import *
Loading