diff --git a/R-package/demo/basic_kvstore.R b/R-package/demo/basic_kvstore.R new file mode 100644 index 000000000000..1a9ba411aa03 --- /dev/null +++ b/R-package/demo/basic_kvstore.R @@ -0,0 +1,18 @@ +require(mxnet) + +kv = mx.kv.create() + +dlist = lapply(1:3, function(i) { + x = as.array(c(i, i+1)) + mat = mx.nd.array(x, mx.cpu(i)) + list(x=mat) +}) +kv$init(c(0), dlist[[1]]) +kv$push(c(0), dlist, 0) +olist = kv$pull(c(0), dlist, 0) + +print(as.array(olist[[1]][[1]])) + + + + diff --git a/R-package/src/base.h b/R-package/src/base.h index 068b090127d2..ca8ceaf3fef2 100644 --- a/R-package/src/base.h +++ b/R-package/src/base.h @@ -357,6 +357,7 @@ inline std::vector Dim2Vec(const Rcpp::Dimension &rshape) { class NDArray; class Symbol; class Executor; +class KVStore; } // namespace R } // namespace mxnet diff --git a/R-package/src/executor.cc b/R-package/src/executor.cc index 4422341229a8..c65b96719f0a 100644 --- a/R-package/src/executor.cc +++ b/R-package/src/executor.cc @@ -77,14 +77,9 @@ Executor::RObjectType Executor::Backward(const RObjectType &exec, << "Expect exec to be " << Executor::TypeName(); RCHECK(Executor::XPtr(exec)->grad_arrays_ != nullptr) << "This executor has not been binded with req.grad"; - Executor::RObjectType ret = Executor::Move(exec); - std::vector grad_handles(output_grads.size()); - for (size_t i = 0; i < output_grads.size(); ++i) { - RCHECK(Rcpp::is(exec)) - << "Expect out_grads be list of " << NDArray::TypeName(); - grad_handles[i] = NDArray::XPtr(output_grads[i])->handle(); - } + std::vector grad_handles + = NDArray::GetHandles(output_grads, "output_grads", false); MX_CALL(MXExecutorBackward(Executor::XPtr(ret)->handle_, static_cast(grad_handles.size()), dmlc::BeginPtr(grad_handles))); diff --git a/R-package/src/io.cc b/R-package/src/io.cc index 40fea33131b7..0aa2d3a6ce5e 100644 --- a/R-package/src/io.cc +++ b/R-package/src/io.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2015 by Contributors - * \file ndarray.cc - * \brief Rcpp NDArray of MXNet. + * \file io.cc + * \brief Rcpp IO module of mxnet. */ #include #include "./base.h" diff --git a/R-package/src/kvstore.cc b/R-package/src/kvstore.cc new file mode 100644 index 000000000000..53e8cb852c84 --- /dev/null +++ b/R-package/src/kvstore.cc @@ -0,0 +1,109 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file kvstore.cc + * \brief Rcpp NDArray of MXNet. + */ +#include +#include "./base.h" +#include "./kvstore.h" +#include "./ndarray.h" + +namespace mxnet { +namespace R { + +void KVStore::Init(const std::vector& keys, const Rcpp::List& weights) { + RCHECK(keys.size() == weights.size()) + << "The length of keys should be same as length of weights"; + std::vector handles = NDArray::GetHandles(weights, "weights"); + MX_CALL(MXKVStoreInit( + handle_, static_cast(handles.size()), + dmlc::BeginPtr(keys), dmlc::BeginPtr(handles))); +} + +void KVStore::Push(const std::vector& keys, + const Rcpp::List& weight_lists, + const std::vector& priority) { + RCHECK(keys.size() == priority.size() || priority.size() == 0) + << "The length of keys should be same as length of priority"; + + std::vector > vec(weight_lists.size()); + for (size_t i = 0; i < weight_lists.size(); ++i) { + RCHECK(Rcpp::is(weight_lists[i])) + << "Expect weight_lists to be list(list(ndarray))"; + Rcpp::List list = Rcpp::as(weight_lists[i]); + RCHECK(list.size() == keys.size()) + << "Expect length of keys to be same as each weight_list"; + vec[i] = NDArray::GetHandles(list, "weight_list"); + } + // do push + std::vector group_keys(vec.size()); + std::vector vals(vec.size()); + for (size_t i = 0; i < keys.size(); ++i) { + for (size_t j = 0; j < vec.size(); ++j) { + vals[j] = vec[j][i]; + } + std::fill(group_keys.begin(), group_keys.end(), keys[i]); + MX_CALL(MXKVStorePush(handle_, + static_cast(vals.size()), + dmlc::BeginPtr(group_keys), + dmlc::BeginPtr(vals), + priority.size() == 0 ? 0 : priority[i])); + } +} + +Rcpp::List KVStore::Pull(const std::vector& keys, + const Rcpp::List& out_lists, + const std::vector& priority) { + RCHECK(keys.size() == priority.size() || priority.size() == 0) + << "The length of keys should be same as length of priority"; + Rcpp::List moved_list(out_lists.size()); + std::vector > vec(out_lists.size()); + for (size_t i = 0; i < out_lists.size(); ++i) { + RCHECK(Rcpp::is(out_lists[i])) + << "Expect out_lists to be list(list(ndarray))"; + Rcpp::List src = Rcpp::as(out_lists[i]); + RCHECK(src.size() == keys.size()) + << "Expect length of keys to be same as each out_lists"; + vec[i] = NDArray::GetHandles(src, "out_list"); + Rcpp::List moved(src.size()); + for (size_t j = 0; j < src.size(); ++j) { + moved[j] = NDArray::Move(src[j]); + } + moved_list[i] = moved; + } + // do pull + std::vector group_keys(vec.size()); + std::vector vals(vec.size()); + for (size_t i = 0; i < keys.size(); ++i) { + for (size_t j = 0; j < vec.size(); ++j) { + vals[j] = vec[j][i]; + } + std::fill(group_keys.begin(), group_keys.end(), keys[i]); + MX_CALL(MXKVStorePull(handle_, static_cast(vals.size()), + dmlc::BeginPtr(group_keys), + dmlc::BeginPtr(vals), + priority.size() == 0 ? 0 : priority[i])); + } + return moved_list; +} + +Rcpp::RObject KVStore::Create(const char *type) { + KVStoreHandle handle; + MX_CALL(MXKVStoreCreate(type, &handle)); + return Rcpp::internal::make_new_object(new KVStore(handle)); +} + +void KVStore::InitRcppModule() { + using namespace Rcpp; // NOLINT(*) + class_("MXKVStore") + .finalizer(&KVStore::Finalizer) + .method("init", &KVStore::Init) + .method("push", &KVStore::Push) + .method("pull", &KVStore::Pull); + + function("mx.kv.create", &KVStore::Create, + List::create(_["type"] = "local"), + "Create a new kvstore"); +} +} // namespace R +} // namespace mxnet diff --git a/R-package/src/kvstore.h b/R-package/src/kvstore.h new file mode 100644 index 000000000000..ed1a4d0857f1 --- /dev/null +++ b/R-package/src/kvstore.h @@ -0,0 +1,76 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file kvstore.h + * \brief Rcpp Parameter Store interface of MXNet + */ +#ifndef MXNET_RCPP_KVSTORE_H_ +#define MXNET_RCPP_KVSTORE_H_ + +#include +#include +#include +#include +#include "./base.h" + +namespace mxnet { +namespace R { + +/*! + * \brief MXNet's Parameter store interface. + */ +class KVStore { + public: + /*! + * \brief initialize all the weights + * \param keys The keys of each weight. + * \param weights the weights NDArray list. + */ + void Init(const std::vector& keys, const Rcpp::List& weights); + /*! + * \brief Push the weights to the KVStore. + * + * This operation will do a aggregation first on weight_lists, the push things out. + * + * sum_list[i] = sum(list[i] for list in weight_lists) + * Then push(keys[i], sum_list[i]) for each i. + * + * \param keys list of keys, corresponds to key of each location. + * \param weight_lists List of Rcpp::List. + * \param priority The priority of each key. + */ + void Push(const std::vector& keys, + const Rcpp::List& weight_lists, + const std::vector& priority); + /*! + * \brief Pull the data back. + * + * \param keys List of keys, corresponds to key of each location. + * \param out_lists List of Rcpp::List + * The list of NDArrays to hold the result, this will be moved. + * \param priority The priority of each key. + * \return The result list of pull. + */ + Rcpp::List Pull(const std::vector& keys, + const Rcpp::List& out_lists, + const std::vector& priority); + /*! + * \brief create a KVStore + * \return the created KVStore + */ + static Rcpp::RObject Create(const char *type); + /*! \brief initialize the R cpp Module */ + static void InitRcppModule(); + + private: + explicit KVStore(KVStoreHandle handle) + : handle_(handle) {} + static void Finalizer(KVStore *kv) { + MX_CALL(MXKVStoreFree(kv->handle_)); + } + /*! \brief internal KVStore handle */ + KVStoreHandle handle_; +}; + +} // namespace R +} // namespace mxnet +#endif // MXNET_RCPP_KVSTORE_H_ diff --git a/R-package/src/mxnet.cc b/R-package/src/mxnet.cc index 311b3b6d3b07..183a8af92167 100644 --- a/R-package/src/mxnet.cc +++ b/R-package/src/mxnet.cc @@ -9,6 +9,7 @@ #include "./symbol.h" #include "./executor.h" #include "./io.h" +#include "./kvstore.h" RCPP_MODULE(mxnet) { using namespace mxnet::R; @@ -20,4 +21,5 @@ RCPP_MODULE(mxnet) { Executor::InitRcppModule(); DataIter::InitRcppModule(); DataIterCreateFunction::InitRcppModule(); + KVStore::InitRcppModule(); } diff --git a/R-package/src/ndarray.cc b/R-package/src/ndarray.cc index 2c6c07982d2c..503eb59bce6a 100644 --- a/R-package/src/ndarray.cc +++ b/R-package/src/ndarray.cc @@ -190,6 +190,24 @@ NDArray::RObjectType NDArray::Empty( return NDArray::RObject(handle); } +std::vector NDArray::GetHandles(const Rcpp::List& array_list, + const std::string& list_name, + bool allow_null) { + std::vector ret(array_list.size()); + for (size_t i = 0; i < ret.size(); ++i) { + if (array_list[i] == R_NilValue) { + RCHECK(allow_null) + << "Expect " << list_name << " to be list of non-NULL " << NDArray::TypeName(); + ret[i] = nullptr; + } else { + RCHECK(Rcpp::is(array_list[i])) + << "Expect " << list_name << " to be list of " << NDArray::TypeName(); + ret[i] = NDArray::XPtr(array_list[i])->handle_; + } + } + return ret; +} + NDArray::RObjectType NDArray::Clone() const { std::vector shape = Dim2Vec(this->shape()); NDArrayHandle handle; @@ -384,7 +402,9 @@ void NDArray::InitRcppModule() { function("mx.nd.internal.load", &NDArray::Load); function("mx.nd.internal.save", &NDArray::Save); function("mx.nd.internal.empty", &NDArray::Empty); - function("mx.nd.array", &NDArray::Array); + function("mx.nd.array", &NDArray::Array, + List::create(_["src"], _["ctx"]), + "Create a new mx.ndarray that copies the content from src on ctx."); } void NDArrayFunction::InitRcppModule() { diff --git a/R-package/src/ndarray.h b/R-package/src/ndarray.h index 1503c297815f..5a31e607c19f 100644 --- a/R-package/src/ndarray.h +++ b/R-package/src/ndarray.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace mxnet { namespace R { @@ -71,6 +72,15 @@ class NDArray : public MXNetMovable { */ static RObjectType Array(const Rcpp::RObject& src, const Context::RObjectType& ctx); + /*! + * \brief Extract NDArrayHandles from List. + * \param array_list The NDArray list. + * \param list_name The name of the list, used for error message. + * \param allow_null If set to True, allow null in the list. + */ + static std::vector GetHandles(const Rcpp::List& array_list, + const std::string& list_name, + bool allow_null = false); /*! \brief static function to initialize the Rcpp functions */ static void InitRcppModule(); @@ -91,6 +101,7 @@ class NDArray : public MXNetMovable { private: // declare friend class friend class NDArrayFunction; + friend class KVStore; friend class Executor; friend class MXNetMovable; // enable trivial operator= etc. diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index d22e02a20666..b7c22ee9fb65 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -735,7 +735,7 @@ MXNET_DLL int MXKVStoreFree(KVStoreHandle handle); */ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, mx_uint num, - int* keys, + const int* keys, NDArrayHandle* vals); /*! @@ -749,7 +749,7 @@ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, */ MXNET_DLL int MXKVStorePush(KVStoreHandle handle, mx_uint num, - int* keys, + const int* keys, NDArrayHandle* vals, int priority); /*! @@ -763,7 +763,7 @@ MXNET_DLL int MXKVStorePush(KVStoreHandle handle, */ MXNET_DLL int MXKVStorePull(KVStoreHandle handle, mx_uint num, - int* keys, + const int* keys, NDArrayHandle* vals, int priority); /*! diff --git a/src/c_api.cc b/src/c_api.cc index 026a293cdb02..1a12fe65ea3d 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -951,7 +951,7 @@ int MXKVStoreFree(KVStoreHandle handle) { int MXKVStoreInit(KVStoreHandle handle, mx_uint num, - int* keys, + const int* keys, NDArrayHandle* vals) { API_BEGIN(); std::vector v_keys(num); @@ -966,7 +966,7 @@ int MXKVStoreInit(KVStoreHandle handle, int MXKVStorePush(KVStoreHandle handle, mx_uint num, - int* keys, + const int* keys, NDArrayHandle* vals, int priority) { API_BEGIN(); @@ -982,7 +982,7 @@ int MXKVStorePush(KVStoreHandle handle, int MXKVStorePull(KVStoreHandle handle, mx_uint num, - int* keys, + const int* keys, NDArrayHandle* vals, int priority) { API_BEGIN();