Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions R-package/demo/basic_kvstore.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
require(mxnet)

kv = mx.kv.create()

dlist = lapply(1:3, function(i) {
x = as.array(c(i, i+1))
mat = mx.nd.array(x, mx.cpu(i))
list(x=mat)
})
kv$init(c(0), dlist[[1]])
kv$push(c(0), dlist, 0)
olist = kv$pull(c(0), dlist, 0)

print(as.array(olist[[1]][[1]]))




1 change: 1 addition & 0 deletions R-package/src/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ inline std::vector<mx_uint> Dim2Vec(const Rcpp::Dimension &rshape) {
class NDArray;
class Symbol;
class Executor;
class KVStore;
} // namespace R
} // namespace mxnet

Expand Down
9 changes: 2 additions & 7 deletions R-package/src/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,9 @@ Executor::RObjectType Executor::Backward(const RObjectType &exec,
<< "Expect exec to be " << Executor::TypeName();
RCHECK(Executor::XPtr(exec)->grad_arrays_ != nullptr)
<< "This executor has not been binded with req.grad";

Executor::RObjectType ret = Executor::Move(exec);
std::vector<NDArrayHandle> grad_handles(output_grads.size());
for (size_t i = 0; i < output_grads.size(); ++i) {
RCHECK(Rcpp::is<NDArray>(exec))
<< "Expect out_grads be list of " << NDArray::TypeName();
grad_handles[i] = NDArray::XPtr(output_grads[i])->handle();
}
std::vector<NDArrayHandle> grad_handles
= NDArray::GetHandles(output_grads, "output_grads", false);
MX_CALL(MXExecutorBackward(Executor::XPtr(ret)->handle_,
static_cast<mx_uint>(grad_handles.size()),
dmlc::BeginPtr(grad_handles)));
Expand Down
4 changes: 2 additions & 2 deletions R-package/src/io.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2015 by Contributors
* \file ndarray.cc
* \brief Rcpp NDArray of MXNet.
* \file io.cc
* \brief Rcpp IO module of mxnet.
*/
#include <Rcpp.h>
#include "./base.h"
Expand Down
109 changes: 109 additions & 0 deletions R-package/src/kvstore.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*!
* Copyright (c) 2015 by Contributors
* \file kvstore.cc
* \brief Rcpp NDArray of MXNet.
*/
#include <Rcpp.h>
#include "./base.h"
#include "./kvstore.h"
#include "./ndarray.h"

namespace mxnet {
namespace R {

void KVStore::Init(const std::vector<int>& keys, const Rcpp::List& weights) {
RCHECK(keys.size() == weights.size())
<< "The length of keys should be same as length of weights";
std::vector<NDArrayHandle> handles = NDArray::GetHandles(weights, "weights");
MX_CALL(MXKVStoreInit(
handle_, static_cast<mx_uint>(handles.size()),
dmlc::BeginPtr(keys), dmlc::BeginPtr(handles)));
}

void KVStore::Push(const std::vector<int>& keys,
const Rcpp::List& weight_lists,
const std::vector<int>& priority) {
RCHECK(keys.size() == priority.size() || priority.size() == 0)
<< "The length of keys should be same as length of priority";

std::vector<std::vector<NDArrayHandle> > vec(weight_lists.size());
for (size_t i = 0; i < weight_lists.size(); ++i) {
RCHECK(Rcpp::is<Rcpp::List>(weight_lists[i]))
<< "Expect weight_lists to be list(list(ndarray))";
Rcpp::List list = Rcpp::as<Rcpp::List>(weight_lists[i]);
RCHECK(list.size() == keys.size())
<< "Expect length of keys to be same as each weight_list";
vec[i] = NDArray::GetHandles(list, "weight_list");
}
// do push
std::vector<int> group_keys(vec.size());
std::vector<NDArrayHandle> vals(vec.size());
for (size_t i = 0; i < keys.size(); ++i) {
for (size_t j = 0; j < vec.size(); ++j) {
vals[j] = vec[j][i];
}
std::fill(group_keys.begin(), group_keys.end(), keys[i]);
MX_CALL(MXKVStorePush(handle_,
static_cast<mx_uint>(vals.size()),
dmlc::BeginPtr(group_keys),
dmlc::BeginPtr(vals),
priority.size() == 0 ? 0 : priority[i]));
}
}

Rcpp::List KVStore::Pull(const std::vector<int>& keys,
const Rcpp::List& out_lists,
const std::vector<int>& priority) {
RCHECK(keys.size() == priority.size() || priority.size() == 0)
<< "The length of keys should be same as length of priority";
Rcpp::List moved_list(out_lists.size());
std::vector<std::vector<NDArrayHandle> > vec(out_lists.size());
for (size_t i = 0; i < out_lists.size(); ++i) {
RCHECK(Rcpp::is<Rcpp::List>(out_lists[i]))
<< "Expect out_lists to be list(list(ndarray))";
Rcpp::List src = Rcpp::as<Rcpp::List>(out_lists[i]);
RCHECK(src.size() == keys.size())
<< "Expect length of keys to be same as each out_lists";
vec[i] = NDArray::GetHandles(src, "out_list");
Rcpp::List moved(src.size());
for (size_t j = 0; j < src.size(); ++j) {
moved[j] = NDArray::Move(src[j]);
}
moved_list[i] = moved;
}
// do pull
std::vector<int> group_keys(vec.size());
std::vector<NDArrayHandle> vals(vec.size());
for (size_t i = 0; i < keys.size(); ++i) {
for (size_t j = 0; j < vec.size(); ++j) {
vals[j] = vec[j][i];
}
std::fill(group_keys.begin(), group_keys.end(), keys[i]);
MX_CALL(MXKVStorePull(handle_, static_cast<mx_uint>(vals.size()),
dmlc::BeginPtr(group_keys),
dmlc::BeginPtr(vals),
priority.size() == 0 ? 0 : priority[i]));
}
return moved_list;
}

Rcpp::RObject KVStore::Create(const char *type) {
KVStoreHandle handle;
MX_CALL(MXKVStoreCreate(type, &handle));
return Rcpp::internal::make_new_object(new KVStore(handle));
}

void KVStore::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
class_<KVStore>("MXKVStore")
.finalizer(&KVStore::Finalizer)
.method("init", &KVStore::Init)
.method("push", &KVStore::Push)
.method("pull", &KVStore::Pull);

function("mx.kv.create", &KVStore::Create,
List::create(_["type"] = "local"),
"Create a new kvstore");
}
} // namespace R
} // namespace mxnet
76 changes: 76 additions & 0 deletions R-package/src/kvstore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*!
* Copyright (c) 2015 by Contributors
* \file kvstore.h
* \brief Rcpp Parameter Store interface of MXNet
*/
#ifndef MXNET_RCPP_KVSTORE_H_
#define MXNET_RCPP_KVSTORE_H_

#include <Rcpp.h>
#include <mxnet/c_api.h>
#include <string>
#include <vector>
#include "./base.h"

namespace mxnet {
namespace R {

/*!
* \brief MXNet's Parameter store interface.
*/
class KVStore {
public:
/*!
* \brief initialize all the weights
* \param keys The keys of each weight.
* \param weights the weights NDArray list.
*/
void Init(const std::vector<int>& keys, const Rcpp::List& weights);
/*!
* \brief Push the weights to the KVStore.
*
* This operation will do a aggregation first on weight_lists, the push things out.
*
* sum_list[i] = sum(list[i] for list in weight_lists)
* Then push(keys[i], sum_list[i]) for each i.
*
* \param keys list of keys, corresponds to key of each location.
* \param weight_lists List of Rcpp::List.
* \param priority The priority of each key.
*/
void Push(const std::vector<int>& keys,
const Rcpp::List& weight_lists,
const std::vector<int>& priority);
/*!
* \brief Pull the data back.
*
* \param keys List of keys, corresponds to key of each location.
* \param out_lists List of Rcpp::List
* The list of NDArrays to hold the result, this will be moved.
* \param priority The priority of each key.
* \return The result list of pull.
*/
Rcpp::List Pull(const std::vector<int>& keys,
const Rcpp::List& out_lists,
const std::vector<int>& priority);
/*!
* \brief create a KVStore
* \return the created KVStore
*/
static Rcpp::RObject Create(const char *type);
/*! \brief initialize the R cpp Module */
static void InitRcppModule();

private:
explicit KVStore(KVStoreHandle handle)
: handle_(handle) {}
static void Finalizer(KVStore *kv) {
MX_CALL(MXKVStoreFree(kv->handle_));
}
/*! \brief internal KVStore handle */
KVStoreHandle handle_;
};

} // namespace R
} // namespace mxnet
#endif // MXNET_RCPP_KVSTORE_H_
2 changes: 2 additions & 0 deletions R-package/src/mxnet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "./symbol.h"
#include "./executor.h"
#include "./io.h"
#include "./kvstore.h"

RCPP_MODULE(mxnet) {
using namespace mxnet::R;
Expand All @@ -20,4 +21,5 @@ RCPP_MODULE(mxnet) {
Executor::InitRcppModule();
DataIter::InitRcppModule();
DataIterCreateFunction::InitRcppModule();
KVStore::InitRcppModule();
}
22 changes: 21 additions & 1 deletion R-package/src/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,24 @@ NDArray::RObjectType NDArray::Empty(
return NDArray::RObject(handle);
}

std::vector<NDArrayHandle> NDArray::GetHandles(const Rcpp::List& array_list,
const std::string& list_name,
bool allow_null) {
std::vector<NDArrayHandle> ret(array_list.size());
for (size_t i = 0; i < ret.size(); ++i) {
if (array_list[i] == R_NilValue) {
RCHECK(allow_null)
<< "Expect " << list_name << " to be list of non-NULL " << NDArray::TypeName();
ret[i] = nullptr;
} else {
RCHECK(Rcpp::is<NDArray>(array_list[i]))
<< "Expect " << list_name << " to be list of " << NDArray::TypeName();
ret[i] = NDArray::XPtr(array_list[i])->handle_;
}
}
return ret;
}

NDArray::RObjectType NDArray::Clone() const {
std::vector<mx_uint> shape = Dim2Vec(this->shape());
NDArrayHandle handle;
Expand Down Expand Up @@ -384,7 +402,9 @@ void NDArray::InitRcppModule() {
function("mx.nd.internal.load", &NDArray::Load);
function("mx.nd.internal.save", &NDArray::Save);
function("mx.nd.internal.empty", &NDArray::Empty);
function("mx.nd.array", &NDArray::Array);
function("mx.nd.array", &NDArray::Array,
List::create(_["src"], _["ctx"]),
"Create a new mx.ndarray that copies the content from src on ctx.");
}

void NDArrayFunction::InitRcppModule() {
Expand Down
11 changes: 11 additions & 0 deletions R-package/src/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <mxnet/c_api.h>
#include <string>
#include <algorithm>
#include <vector>

namespace mxnet {
namespace R {
Expand Down Expand Up @@ -71,6 +72,15 @@ class NDArray : public MXNetMovable<NDArray> {
*/
static RObjectType Array(const Rcpp::RObject& src,
const Context::RObjectType& ctx);
/*!
* \brief Extract NDArrayHandles from List.
* \param array_list The NDArray list.
* \param list_name The name of the list, used for error message.
* \param allow_null If set to True, allow null in the list.
*/
static std::vector<NDArrayHandle> GetHandles(const Rcpp::List& array_list,
const std::string& list_name,
bool allow_null = false);
/*! \brief static function to initialize the Rcpp functions */
static void InitRcppModule();

Expand All @@ -91,6 +101,7 @@ class NDArray : public MXNetMovable<NDArray> {
private:
// declare friend class
friend class NDArrayFunction;
friend class KVStore;
friend class Executor;
friend class MXNetMovable<NDArray>;
// enable trivial operator= etc.
Expand Down
6 changes: 3 additions & 3 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ MXNET_DLL int MXKVStoreFree(KVStoreHandle handle);
*/
MXNET_DLL int MXKVStoreInit(KVStoreHandle handle,
mx_uint num,
int* keys,
const int* keys,
NDArrayHandle* vals);

/*!
Expand All @@ -749,7 +749,7 @@ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle,
*/
MXNET_DLL int MXKVStorePush(KVStoreHandle handle,
mx_uint num,
int* keys,
const int* keys,
NDArrayHandle* vals,
int priority);
/*!
Expand All @@ -763,7 +763,7 @@ MXNET_DLL int MXKVStorePush(KVStoreHandle handle,
*/
MXNET_DLL int MXKVStorePull(KVStoreHandle handle,
mx_uint num,
int* keys,
const int* keys,
NDArrayHandle* vals,
int priority);
/*!
Expand Down
6 changes: 3 additions & 3 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ int MXKVStoreFree(KVStoreHandle handle) {

int MXKVStoreInit(KVStoreHandle handle,
mx_uint num,
int* keys,
const int* keys,
NDArrayHandle* vals) {
API_BEGIN();
std::vector<int> v_keys(num);
Expand All @@ -966,7 +966,7 @@ int MXKVStoreInit(KVStoreHandle handle,

int MXKVStorePush(KVStoreHandle handle,
mx_uint num,
int* keys,
const int* keys,
NDArrayHandle* vals,
int priority) {
API_BEGIN();
Expand All @@ -982,7 +982,7 @@ int MXKVStorePush(KVStoreHandle handle,

int MXKVStorePull(KVStoreHandle handle,
mx_uint num,
int* keys,
const int* keys,
NDArrayHandle* vals,
int priority) {
API_BEGIN();
Expand Down